Skip to content

Commit b176526

Browse files
authored
remove num_token from EngineOutput (#4088)
1 parent 2d759a4 commit b176526

File tree

7 files changed

+10
-17
lines changed

7 files changed

+10
-17
lines changed

benchmark/profile_throughput.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ async def _inference(self, req_queue: Queue, session_id: int, temperature: float
178178
stream_output=stream_output)
179179
try:
180180
async for outputs in generator:
181-
n_token += outputs.num_token
181+
n_token += len(outputs.token_ids)
182182
token_ids += outputs.token_ids
183183
if not skip_detokenize:
184184
_, state = self.tokenizer.detokenize_incrementally(token_ids, state)

lmdeploy/messages.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -526,12 +526,11 @@ class RequestMetrics:
526526

527527
@dataclass
528528
class EngineOutput:
529-
"""Engine output for turbomind/pytorch engine.
529+
"""Engine output from turbomind/pytorch engine.
530530
531531
Args:
532532
status (ResponseType): the response type.
533533
token_ids (List[int]): the newly generated token ids in each iteration.
534-
num_token (int): the newly generated token number, equal to `len(token_ids)`
535534
logprobs (List[Dict[int, float]]): the top logprobs for each output
536535
position.
537536
cache_block_ids (List[int]): send cache blocks back for migration in
@@ -540,7 +539,6 @@ class EngineOutput:
540539
"""
541540
status: ResponseType
542541
token_ids: List[int]
543-
num_token: int
544542
logprobs: List[Dict[int, float]] = None
545543
logits: torch.Tensor = None
546544
last_hidden_state: torch.Tensor = None

lmdeploy/metrics/stats.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ def update_from_output(self, outputs: EngineOutput, req_state: RequestState):
198198
outputs (EngineOutput): The output from the engine containing information about the current iteration.
199199
req_state (RequestState): The state of the request, including timestamps and token counts.
200200
"""
201-
new_generation_tokens = outputs.num_token
201+
new_generation_tokens = len(outputs.token_ids)
202202
if new_generation_tokens == 0:
203203
return
204204
self.new_generation_tokens = new_generation_tokens
@@ -213,7 +213,7 @@ def update_from_output(self, outputs: EngineOutput, req_state: RequestState):
213213
# update the latest token generation time
214214
req_state.lastest_token_time = outputs.req_metrics.token_timestamp
215215
# update the number of generated tokens
216-
req_state.generation_tokens += outputs.num_token
216+
req_state.generation_tokens += new_generation_tokens
217217

218218
if outputs.status != ResponseType.SUCCESS:
219219
req_state.finish_reason = outputs.status

lmdeploy/pytorch/engine/engine_instance.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ async def async_stream_infer(self,
126126
int: The number of the output tokens.
127127
"""
128128
if len(input_ids) > self.max_input_len:
129-
yield EngineOutput(ResponseType.INPUT_LENGTH_ERROR, [], 0)
129+
yield EngineOutput(ResponseType.INPUT_LENGTH_ERROR, [])
130130
return
131131
gen_config = gen_config or GenerationConfig()
132132
sampling_param = SamplingParam.from_gen_config(gen_config=gen_config)
@@ -158,7 +158,6 @@ async def async_stream_infer(self,
158158
logger.debug(f'session[{session_id}] success: num_out_ids={num_ids}.')
159159
yield EngineOutput(resp.type,
160160
token_ids[output_offset:],
161-
num_ids,
162161
cache_block_ids=cache_block_ids,
163162
req_metrics=req_metrics,
164163
logprobs=logprobs)
@@ -171,15 +170,14 @@ async def async_stream_infer(self,
171170
logger.debug(f'session[{session_id}] finish: num_out_ids={num_ids}.')
172171
yield EngineOutput(resp.type,
173172
token_ids[output_offset:],
174-
num_ids,
175173
logits=logits,
176174
cache_block_ids=cache_block_ids,
177175
req_metrics=req_metrics,
178176
logprobs=logprobs)
179177
break
180178
else:
181179
logger.debug(f'session[{session_id}] failed.')
182-
yield EngineOutput(resp.type, [], 0)
180+
yield EngineOutput(resp.type, [])
183181
break
184182

185183
async def async_infer(self,

lmdeploy/pytorch/engine/mp_engine/base_worker.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def __init__(self):
138138

139139
def get(self, stream_id):
140140
if stream_id not in self._output:
141-
self._output[stream_id] = EngineOutput(status=None, token_ids=[], num_token=0, logprobs=[])
141+
self._output[stream_id] = EngineOutput(status=None, token_ids=[], logprobs=[])
142142
return self._output[stream_id]
143143

144144
def add(self, stream_id, result):
@@ -154,5 +154,4 @@ def pop(self, stream_id, result):
154154
output = self._output.pop(stream_id)
155155
result.token_ids = output.token_ids or []
156156
result.logprobs = output.logprobs or None
157-
result.num_token = len(output.token_ids)
158157
return result

lmdeploy/serve/async_engine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -854,7 +854,7 @@ def is_error(status):
854854
if is_error(outputs.status):
855855
break
856856

857-
output_len = outputs.num_token
857+
output_len = len(outputs.token_ids)
858858
if hit_stop_token:
859859
continue
860860

lmdeploy/turbomind/turbomind.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -760,7 +760,6 @@ async def async_stream_infer(self,
760760
state = None
761761

762762
output_ids = []
763-
output_len = 0
764763
prev_len = step + input_len
765764
try:
766765
while True:
@@ -782,8 +781,7 @@ async def async_stream_infer(self,
782781
continue
783782

784783
output_ids = output_ids_buf[prev_len:seq_len].tolist()
785-
output_len = seq_len - prev_len
786-
output = EngineOutput(ret_status, output_ids, output_len)
784+
output = EngineOutput(ret_status, output_ids)
787785

788786
for f in extra_fs:
789787
f(output, seq_len)
@@ -811,7 +809,7 @@ async def async_stream_infer(self,
811809
logger.info(f'[async_stream_infer] session {session_id} done')
812810

813811
def _get_error_output(self, status):
814-
return EngineOutput(status=self.errcode_map[status], token_ids=[], num_token=0)
812+
return EngineOutput(status=self.errcode_map[status], token_ids=[])
815813

816814
def _get_generation_config(self, cfg: GenerationConfig):
817815
c = _tm.GenerationConfig()

0 commit comments

Comments
 (0)