3232 speculate_save_output , speculate_set_value_by_flags_and_idx ,
3333 speculate_step_paddle , speculate_step_system_cache ,
3434 speculate_update_v3 , step_paddle , step_system_cache , update_inputs ,
35- step_reschedule )
36- from fastdeploy .worker .output import ModelOutputData
35+ step_reschedule , save_output_topk )
36+ from fastdeploy .worker .output import (ModelOutputData , ModelRunnerOutput ,
37+ SamplerOutput )
3738
3839DISABLE_RECOVER = (envs .FD_DISABLED_RECOVER == "1" )
3940
@@ -109,10 +110,10 @@ def pre_process(
109110 cu_seqlens_k , output_cum_offsets , output_padding_offset )
110111
111112
112- def post_process_normal (sampled_token_ids : paddle . Tensor ,
113+ def post_process_normal (sampler_output : SamplerOutput ,
113114 model_output : ModelOutputData ,
114115 save_each_rank : bool = False ,
115- skip_save_output : bool = False ) -> None :
116+ skip_save_output : bool = False ) -> ModelRunnerOutput :
116117 """ Post-processing steps after completing a single token generation. """
117118 # 1. Set stop value
118119 paddle .assign (
@@ -130,7 +131,8 @@ def post_process_normal(sampled_token_ids: paddle.Tensor,
130131 model_output .stop_flags ,
131132 )
132133 # TODO(gongshaotian): Add use_stop_seqs
133- set_stop_value_multi_ends (sampled_token_ids , model_output .stop_flags ,
134+ set_stop_value_multi_ends (sampler_output .sampled_token_ids ,
135+ model_output .stop_flags ,
134136 model_output .seq_lens_this_time ,
135137 model_output .eos_token_id ,
136138 model_output .next_tokens , False ) # multi ends
@@ -145,18 +147,26 @@ def post_process_normal(sampled_token_ids: paddle.Tensor,
145147 model_output .seq_lens_decoder ,
146148 model_output .input_ids ,
147149 model_output .stop_nums ,
148- sampled_token_ids ,
150+ sampler_output . sampled_token_ids ,
149151 model_output .is_block_step ,
150152 )
151153 # 3. Transmit the model's output and stop generation signal via message queue.
152154 # In the future, we will abandon this approach.
153155 if not skip_save_output :
154- save_output (
155- sampled_token_ids ,
156- model_output .not_need_stop ,
157- model_output .mp_rank ,
158- save_each_rank , # save_each_rank
159- )
156+ if sampler_output .logprobs_tensors is None :
157+ save_output (
158+ sampler_output .sampled_token_ids ,
159+ model_output .not_need_stop ,
160+ model_output .mp_rank ,
161+ save_each_rank , # save_each_rank
162+ )
163+ else :
164+ save_output_topk (
165+ sampler_output .logprobs_tensors .logprob_token_ids ,
166+ sampler_output .logprobs_tensors .logprobs ,
167+ model_output .not_need_stop ,
168+ model_output .mp_rank ,
169+ )
160170
161171
162172def post_process_specualate (model_output , skip_save_output : bool = False ):
@@ -201,7 +211,7 @@ def post_process_specualate(model_output, skip_save_output: bool = False):
201211 )
202212
203213
204- def post_process (sampled_token_ids : paddle . Tensor ,
214+ def post_process (sampler_output : SamplerOutput ,
205215 model_output : ModelOutputData ,
206216 save_each_rank : bool = False ,
207217 speculative_decoding : bool = False ,
@@ -210,7 +220,7 @@ def post_process(sampled_token_ids: paddle.Tensor,
210220 if speculative_decoding :
211221 post_process_specualate (model_output , skip_save_output )
212222 else :
213- post_process_normal (sampled_token_ids , model_output , save_each_rank ,
223+ post_process_normal (sampler_output , model_output , save_each_rank ,
214224 skip_save_output )
215225
216226
0 commit comments