
Commit 5b7e3b2

Merge pull request #1 from ckl117/logprobs
infer engine support base logprobs
2 parents: c8cdb94 + 74698f0

8 files changed (+164, -48 lines)

custom_ops/gpu_ops/get_output_msg_with_topk.cc

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@
 #endif
 
 #define MAX_BSZ 512
-#define K 10
+#define K 20
 
 struct msgdata {
   long mtype;

custom_ops/gpu_ops/save_output_msg_with_topk.cc

Lines changed: 9 additions & 22 deletions
@@ -23,8 +23,8 @@
 #define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
 #endif
 
-#define MAX_BSZ 128
-#define K 10
+#define MAX_BSZ 512
+#define K 20
 // #define SAVE_WITH_OUTPUT_DEBUG
 
 struct msgdata {
@@ -35,22 +35,15 @@ struct msgdata {
 
 void SaveOutMmsgTopK(const paddle::Tensor& x,
                      const paddle::Tensor& scores,
-                     const paddle::Tensor& topk_ids,
-                     const paddle::Tensor& topk_scores, // [bsz, k]
                      const paddle::Tensor& not_need_stop,
-                     int k,
                      int64_t rank_id) {
   if (rank_id > 0) {
     return;
   }
   auto x_cpu = x.copy_to(paddle::CPUPlace(), false);
   auto scores_cpu = scores.copy_to(paddle::CPUPlace(), false);
-  auto topk_ids_cpu = topk_ids.copy_to(paddle::CPUPlace(), false);
-  auto topk_scores_cpu = topk_scores.copy_to(paddle::CPUPlace(), false);
   int64_t* x_data = x_cpu.data<int64_t>();
   float* scores_data = scores_cpu.data<float>();
-  int64_t* topk_ids_data = topk_ids_cpu.data<int64_t>();
-  float* topk_scores_data = topk_scores_cpu.data<float>();
   static struct msgdata msg_sed;
   int msg_queue_id = 1;
   if (const char* inference_msg_queue_id_env_p =
@@ -106,20 +99,14 @@ void SaveOutMmsgTopK(const paddle::Tensor& x,
   msg_sed.mtext[0] = not_need_stop_data ? inference_msg_id_from_env
                                         : -inference_msg_id_from_env;
   int bsz = x.shape()[0];
+  int token_num = x.shape()[1];
+  int k = token_num - 1;
   msg_sed.mtext[1] = bsz;
   for (int i = 0; i < bsz; i++) {
-    for (int j = 0; j < k + 1; j++) {
+    for (int j = 0; j < token_num; j++) {
       const int64_t offset = i * (K + 1) + j;
-      if (j == 0) {
-        msg_sed.mtext[offset + 2] = (int)x_data[i];
-        msg_sed.mtext_f[offset] = scores_data[i];
-      } else if (j <= k + 1) {
-        msg_sed.mtext[offset + 2] = (int)topk_ids_data[i * k + j - 1];
-        msg_sed.mtext_f[offset] = topk_scores_data[i * k + j - 1];
-      } else {
-        msg_sed.mtext[offset + 2] = -1;
-        msg_sed.mtext_f[offset] = 0.0;
-      }
+      msg_sed.mtext[offset + 2] = (int)x_data[i * token_num + j];
+      msg_sed.mtext_f[offset] = scores_data[i * token_num + j];
     }
   }
 #ifdef SAVE_WITH_OUTPUT_DEBUG
@@ -139,8 +126,8 @@ void SaveOutMmsgTopK(const paddle::Tensor& x,
 }
 
 PD_BUILD_STATIC_OP(save_output_topk)
-    .Inputs({"x", "scores", "topk_ids", "topk_scores", "not_need_stop"})
-    .Attrs({"k: int", "rank_id: int64_t"})
+    .Inputs({"x", "scores", "not_need_stop"})
+    .Attrs({"rank_id: int64_t"})
    .Outputs({"x_out"})
    .SetInplaceMap({{"x", "x_out"}})
    .SetKernelFn(PD_KERNEL(SaveOutMmsgTopK));
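As a reading aid (not part of the commit): with the explicit topk inputs removed, x arrives as a [bsz, token_num] int64 tensor and scores as the matching float tensor, where token_num = k + 1 and, when fed from the Python sampler below, slot 0 holds the sampled token and slots 1..k the top-k candidates. Each request occupies a fixed stride of K + 1 positions in the System V message, and both this op and get_output_msg_with_topk.cc must agree on K and MAX_BSZ. The pack/unpack helpers below are a hypothetical Python sketch of that layout; the exact mtext/mtext_f buffer capacities are an assumption inferred from the indexing and are not shown in this diff.

# Hypothetical sketch of SaveOutMmsgTopK's message layout (not FastDeploy code).
# mtext[0] : stop/continue signal, mtext[1] : batch size
# mtext[i*(K+1) + j + 2] : token id for request i, slot j
# mtext_f[i*(K+1) + j]   : logprob  for request i, slot j
K = 20        # must match the #define shared by the save/get ops
MAX_BSZ = 512

def pack(x, scores, signal):
    # x, scores: per-request lists of length token_num (= k + 1, slot 0 = sampled token)
    bsz, token_num = len(x), len(x[0])
    mtext = [signal, bsz] + [0] * (MAX_BSZ * (K + 1))   # assumed capacity
    mtext_f = [0.0] * (MAX_BSZ * (K + 1))               # assumed capacity
    for i in range(bsz):
        for j in range(token_num):
            offset = i * (K + 1) + j
            mtext[offset + 2] = int(x[i][j])
            mtext_f[offset] = float(scores[i][j])
    return mtext, mtext_f

def unpack(mtext, mtext_f, token_num):
    bsz = mtext[1]
    return [[(mtext[i * (K + 1) + j + 2], mtext_f[i * (K + 1) + j])
             for j in range(token_num)]
            for i in range(bsz)]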

custom_ops/setup_ops_base.py

Lines changed: 1 addition & 0 deletions
@@ -22,6 +22,7 @@
         "gpu_ops/save_with_output_msg.cc",
         "gpu_ops/get_output.cc",
         "gpu_ops/get_output_msg_with_topk.cc",
+        "gpu_ops/save_output_msg_with_topk.cc",
         "gpu_ops/transfer_output.cc",
         "cpu_ops/rebuild_padding.cc",
     ],

fastdeploy/model_executor/layers/sample/meta_data.py

Lines changed: 1 addition & 0 deletions
@@ -42,3 +42,4 @@ class SamplingMetadata:
 
     top_p: paddle.Tensor
     top_k: Optional[paddle.Tensor] = None
+    max_num_logprobs: Optional[int] = None

fastdeploy/model_executor/layers/sample/sampler.py

Lines changed: 65 additions & 2 deletions
@@ -29,6 +29,7 @@
     apply_penalty_multi_scores, apply_speculative_penalty_multi_scores,
     top_p_sampling)
 from fastdeploy.platforms import current_platform
+from fastdeploy.worker.output import LogprobsTensors, SamplerOutput
 
 
 class SamplerProcessor:
@@ -189,14 +190,65 @@ def pre_process(self, skip_idx_list: List[int] = []):
         """ pre process before running """
         self.processor.pre_process(skip_idx_list)
 
+    def compute_logprobs(self, logits: paddle.Tensor) -> paddle.Tensor:
+        """
+        """
+        return F.log_softmax(logits, axis=-1)
+
+    def gather_logprobs(
+        self,
+        logprobs: paddle.Tensor,
+        num_logprobs: int,
+        token_ids: paddle.Tensor,
+    ) -> LogprobsTensors:
+        """
+        Gather logprobs for topk and sampled/prompt token.
+
+        Args:
+          logprobs: (num tokens) x (vocab) tensor
+          num_logprobs: minimum number of logprobs to
+                        retain per token
+          token_ids: prompt tokens (if prompt logprobs)
+                     or sampled tokens (if sampled
+                     logprobs); 1D token ID tensor
+                     with (num tokens) elements
+                     Must be int64.
+
+        Returns:
+          Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
+          Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
+          Sampled token rank tensor, (num tokens)
+        """
+        assert token_ids.dtype == paddle.int64
+        # Find the topK values.
+        token_logprobs = paddle.take_along_axis(logprobs, token_ids, axis=-1)
+        if num_logprobs >= 1:
+            topk_logprobs, topk_indices = paddle.topk(logprobs,
+                                                      num_logprobs,
+                                                      axis=-1)
+            indices = paddle.concat([token_ids, topk_indices], axis=1)
+            top_logprobs = paddle.concat([token_logprobs, topk_logprobs], axis=1)
+        else:
+            indices = token_ids
+            top_logprobs = token_logprobs
+
+        # Compute the ranks of the actual token.
+        token_ranks = (logprobs >= token_logprobs).sum(-1)
+
+        return LogprobsTensors(indices, top_logprobs, token_ranks)
+
     def forward_cuda(
         self,
         logits: paddle.Tensor,
         sampling_metadata: SamplingMetadata,
         skip_idx_list: List[int] = [],
-    ) -> paddle.Tensor:
+    ) -> SamplerOutput:
         """
         """
+        num_logprobs = sampling_metadata.max_num_logprobs
+        if num_logprobs is not None:
+            raw_logprobs = self.compute_logprobs(logits)
+
         logits = self.processor.apply_token_mask(logits, skip_idx_list)
 
         logits = apply_penalty_multi_scores(
@@ -216,8 +268,19 @@ def forward_cuda(
 
         _, next_tokens = top_p_sampling(probs, sampling_metadata.top_p)
 
+        logprobs_tensors = None if num_logprobs is None else \
+            self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=next_tokens)
+
         self.processor.update_output_tokens(next_tokens, skip_idx_list)
-        return next_tokens
+
+        sampler_output = SamplerOutput(
+            # The sampled tokens are expanded to 2D tensor with shape
+            # [num_requests, 1], where each row represents one generated
+            # token per request.
+            sampled_token_ids=next_tokens,
+            logprobs_tensors=logprobs_tensors,
+        )
+        return sampler_output
 
 
 class SpeculativeSampler(nn.Layer):
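A minimal NumPy sketch (illustrative, not the FastDeploy implementation) of what compute_logprobs and gather_logprobs produce: log-softmax over the vocab, the sampled token's logprob in column 0 followed by the top-k, and the 1-based rank of the sampled token within the full distribution. The function name, variable names, and values below are made up for the example.

import numpy as np

def gather_logprobs_sketch(logits: np.ndarray, token_ids: np.ndarray, k: int):
    # log-softmax over the vocab axis
    logits = logits - logits.max(axis=-1, keepdims=True)
    logprobs = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))

    # logprob of the sampled token, shape [num_tokens, 1]
    token_logprobs = np.take_along_axis(logprobs, token_ids, axis=-1)

    # top-k token ids and logprobs, shape [num_tokens, k]
    topk_indices = np.argsort(-logprobs, axis=-1)[:, :k]
    topk_logprobs = np.take_along_axis(logprobs, topk_indices, axis=-1)

    # column 0 is the sampled token, columns 1..k are the top-k
    indices = np.concatenate([token_ids, topk_indices], axis=1)
    top_logprobs = np.concatenate([token_logprobs, topk_logprobs], axis=1)

    # 1-based rank of the sampled token
    token_ranks = (logprobs >= token_logprobs).sum(axis=-1)
    return indices, top_logprobs, token_ranks

logits = np.array([[2.0, 0.5, 0.1, -1.0]])   # one request, vocab of 4
sampled = np.array([[1]], dtype=np.int64)    # pretend token 1 was sampled
ids, lps, ranks = gather_logprobs_sketch(logits, sampled, k=2)
# ids[0]   -> [1, 0, 1]  (sampled token first, then top-2)
# ranks[0] -> 2          (token 1 is the 2nd most likely)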

fastdeploy/model_executor/pre_and_post_process.py

Lines changed: 24 additions & 14 deletions
@@ -32,8 +32,9 @@
     speculate_save_output, speculate_set_value_by_flags_and_idx,
     speculate_step_paddle, speculate_step_system_cache,
     speculate_update_v3, step_paddle, step_system_cache, update_inputs,
-    step_reschedule)
-from fastdeploy.worker.output import ModelOutputData
+    step_reschedule, save_output_topk)
+from fastdeploy.worker.output import (ModelOutputData, ModelRunnerOutput,
+                                      SamplerOutput)
 
 DISABLE_RECOVER = (envs.FD_DISABLED_RECOVER == "1")
 
@@ -109,10 +110,10 @@ def pre_process(
         cu_seqlens_k, output_cum_offsets, output_padding_offset)
 
 
-def post_process_normal(sampled_token_ids: paddle.Tensor,
+def post_process_normal(sampler_output: SamplerOutput,
                         model_output: ModelOutputData,
                         save_each_rank: bool = False,
-                        skip_save_output: bool = False) -> None:
+                        skip_save_output: bool = False) -> ModelRunnerOutput:
     """ Post-processing steps after completing a single token generation. """
     # 1. Set stop value
     paddle.assign(
@@ -130,7 +131,8 @@ def post_process_normal(sampled_token_ids: paddle.Tensor,
         model_output.stop_flags,
     )
     # TODO(gongshaotian): Add use_stop_seqs
-    set_stop_value_multi_ends(sampled_token_ids, model_output.stop_flags,
+    set_stop_value_multi_ends(sampler_output.sampled_token_ids,
+                              model_output.stop_flags,
                               model_output.seq_lens_this_time,
                               model_output.eos_token_id,
                               model_output.next_tokens, False)  # multi ends
@@ -145,18 +147,26 @@ def post_process_normal(sampled_token_ids: paddle.Tensor,
         model_output.seq_lens_decoder,
         model_output.input_ids,
         model_output.stop_nums,
-        sampled_token_ids,
+        sampler_output.sampled_token_ids,
        model_output.is_block_step,
     )
     # 3. Transmit the model's output and stop generation signal via message queue.
     # In the future, we will abandon this approach.
     if not skip_save_output:
-        save_output(
-            sampled_token_ids,
-            model_output.not_need_stop,
-            model_output.mp_rank,
-            save_each_rank,  # save_each_rank
-        )
+        if sampler_output.logprobs_tensors is None:
+            save_output(
+                sampler_output.sampled_token_ids,
+                model_output.not_need_stop,
+                model_output.mp_rank,
+                save_each_rank,  # save_each_rank
+            )
+        else:
+            save_output_topk(
+                sampler_output.logprobs_tensors.logprob_token_ids,
+                sampler_output.logprobs_tensors.logprobs,
+                model_output.not_need_stop,
+                model_output.mp_rank,
+            )
 
 
 def post_process_specualate(model_output, skip_save_output: bool = False):
@@ -201,7 +211,7 @@ def post_process_specualate(model_output, skip_save_output: bool = False):
     )
 
 
-def post_process(sampled_token_ids: paddle.Tensor,
+def post_process(sampler_output: SamplerOutput,
                  model_output: ModelOutputData,
                  save_each_rank: bool = False,
                  speculative_decoding: bool = False,
@@ -210,7 +220,7 @@ def post_process(sampled_token_ids: paddle.Tensor,
     if speculative_decoding:
         post_process_specualate(model_output, skip_save_output)
     else:
-        post_process_normal(sampled_token_ids, model_output, save_each_rank,
+        post_process_normal(sampler_output, model_output, save_each_rank,
                             skip_save_output)

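For orientation (illustrative, not from the commit): the two tensors handed to save_output_topk are the LogprobsTensors fields built by gather_logprobs above, so column 0 of logprob_token_ids/logprobs carries the sampled token and columns 1..k the top-k candidates. The field names come from the diff; the values below are invented.

import numpy as np

logprob_token_ids = np.array([[7, 3, 7, 12]], dtype=np.int64)      # [bsz, k+1]
logprobs = np.array([[-0.9, -0.4, -0.9, -2.1]], dtype=np.float32)  # [bsz, k+1]

for ids_row, lp_row in zip(logprob_token_ids, logprobs):
    sampled_id, sampled_lp = int(ids_row[0]), float(lp_row[0])
    topk = list(zip(ids_row[1:].tolist(), lp_row[1:].tolist()))
    # e.g. sampled token 7 with logprob -0.9, top-k = [(3, -0.4), (7, -0.9), (12, -2.1)]
    print(sampled_id, sampled_lp, topk)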
fastdeploy/worker/gpu_model_runner.py

Lines changed: 9 additions & 8 deletions
@@ -582,6 +582,7 @@ def _prepare_inputs(self) -> None:
             min_dec_lens=self.share_inputs["min_dec_len"],
             bad_words_token_ids=self.share_inputs["bad_tokens"],
             eos_token_ids=self.share_inputs["eos_token_id"],
+            max_num_logprobs=None,
         )
 
     def load_model(self) -> None:
@@ -786,15 +787,15 @@ def _dummy_run(self,
                     self.share_inputs["step_idx"],
                     self.share_inputs["stop_flags"],
                 )
-                sampled_token_ids = self.sampler(logits,
-                                                 self.sampling_metadata)
+                sampler_output = self.sampler(logits,
+                                              self.sampling_metadata)
                 if self.parallel_config.tensor_parallel_degree > 1:
-                    paddle.distributed.broadcast(sampled_token_ids, 0)
+                    paddle.distributed.broadcast(sampler_output.sampled_token_ids, 0)
             else:
                 self.sampler(logits, self.sampling_metadata,
                              self.parallel_config.max_model_len,
                              self.share_inputs)
-                sampled_token_ids = None
+                sampler_output = None
                 if self.parallel_config.tensor_parallel_degree > 1:
                     paddle.distributed.broadcast(
                         self.share_inputs["accept_tokens"], 0)
@@ -834,7 +835,7 @@ def _dummy_run(self,
                 accept_num=self.share_inputs["accept_num"]
                 if self.speculative_decoding else None)
 
-            post_process(sampled_token_ids=sampled_token_ids,
+            post_process(sampler_output=sampler_output,
                          model_output=model_output_data,
                          speculative_decoding=self.speculative_decoding,
                          skip_save_output=True)
@@ -1021,18 +1022,18 @@ class at the server level, which is too granular for ModelRunner.
                 self.share_inputs["step_idx"],
                 self.share_inputs["stop_flags"],
             )
-            sampled_token_ids = self.sampler(
+            sampler_output = self.sampler(
                 logits,
                 self.sampling_metadata,
                 skip_idx_list,
            )
            if self.parallel_config.tensor_parallel_degree > 1:
-                paddle.distributed.broadcast(sampled_token_ids, 0)
+                paddle.distributed.broadcast(sampler_output.sampled_token_ids, 0)
 
         else:
             self.sampler(logits, self.sampling_metadata,
                          self.parallel_config.max_model_len, self.share_inputs)
-            sampled_token_ids = None
+            sampler_output = None
             if self.parallel_config.tensor_parallel_degree > 1:
                 paddle.distributed.broadcast(
                     self.share_inputs["accept_tokens"], 0)
@@ -1075,7 +1076,7 @@ class at the server level, which is too granular for ModelRunner.
             skip_save_output = True
         else:
             skip_save_output = False
-        post_process(sampled_token_ids=sampled_token_ids,
+        post_process(sampler_output=sampler_output,
                      model_output=model_output_data,
                      save_each_rank=self.parallel_config.use_ep,
                      speculative_decoding=self.speculative_decoding,
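To make the new sampler-to-post-process contract concrete, here are hypothetical stand-ins for the classes imported from fastdeploy.worker.output (their real definitions are not part of this diff): the runner broadcasts only sampled_token_ids under tensor parallelism and hands the whole object to post_process; with max_num_logprobs left as None in _prepare_inputs, logprobs_tensors stays None and the plain save_output path is still taken. Only logprob_token_ids and logprobs appear by name in the diff; the remaining field name is a guess.

from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class LogprobsTensors:
    logprob_token_ids: Any    # [num_tokens, num_logprobs + 1], sampled token first
    logprobs: Any             # [num_tokens, num_logprobs + 1] float logprobs
    sampled_token_ranks: Any  # [num_tokens]; field name assumed, not in the diff

@dataclass
class SamplerOutput:
    sampled_token_ids: Any                       # [num_requests, 1]
    logprobs_tensors: Optional[LogprobsTensors]  # None unless max_num_logprobs is set

# With max_num_logprobs=None (as _prepare_inputs sets it in this commit),
# post_process_normal takes the save_output branch rather than save_output_topk.
out = SamplerOutput(sampled_token_ids=[[7]], logprobs_tensors=None)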
