@@ -85,7 +85,7 @@ async def _read_inputs_from_queue(queue: asyncio.Queue, input_timeout: Optional[

    def step(
        self, inputs: torch.Tensor, prompts: torch.Tensor, hypo_ids: torch.LongTensor, *,
-        step_id: str, last_validated_position: int
+        step_id: str, start_from_position: int
    ) -> torch.Tensor:
        """
        Inference step: send a chunk of input tensors and receive a chunk of outputs
@@ -95,11 +95,11 @@ def step(
        if self.closed:
            raise Exception("Session is closed, cannot perform step")

-        if last_validated_position is not None:
-            assert last_validated_position <= self._position
-            self._position = last_validated_position
-            if self.history is not None and self.history.shape[1] >= last_validated_position:
-                self.history = self.history[:, :last_validated_position, :] if last_validated_position > 0 else None
+        if start_from_position is not None:
+            assert start_from_position <= self._position
+            self._position = start_from_position
+            if self.history is not None and self.history.shape[1] >= start_from_position:
+                self.history = self.history[:, :start_from_position, :] if start_from_position > 0 else None

        n_input_tokens = inputs.shape[1]
        if self.history is None:
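For readers skimming the hunk above, here is a minimal, self-contained sketch of what the renamed argument does to the server-side cache. The tensor shape and hidden size are made up for illustration; only the truncation logic mirrors the diff:

```python
import torch

# Stand-in for a server-side session cache: [batch, processed_positions, hidden_dim]
history = torch.randn(1, 10, 4096)
position = 10

# Rewinding to position 6 keeps the first 6 cached positions and drops the rest,
# mirroring how the hunk above resets self._position and slices self.history.
start_from_position = 6
assert start_from_position <= position
position = start_from_position
if history is not None and history.shape[1] >= start_from_position:
    history = history[:, :start_from_position, :] if start_from_position > 0 else None

print(position, history.shape)  # 6 torch.Size([1, 6, 4096])
```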
@@ -122,8 +122,8 @@ def step(
        request_metadata = dict(session_id=self.session_id, step_id=step_id)
        if not self.stepped:
            request_metadata.update(self.session_metadata)
-            if last_validated_position is not None:
-                request_metadata["last_validated_position"] = last_validated_position
+            if start_from_position is not None:
+                request_metadata["start_from_position"] = start_from_position
        elif self.config.use_server_to_server:
            next_servers = self._collect_next_servers()
            if next_servers:
@@ -267,11 +267,11 @@ def __enter__(self) -> "InferenceSession":

    def step(
        self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None,
-        hypo_ids: Optional[torch.Tensor] = None, last_validated_position: Optional[int] = None
+        hypo_ids: Optional[torch.Tensor] = None, start_from_position: Optional[int] = None
    ) -> torch.Tensor:

-        if last_validated_position is not None:
-            self._position = last_validated_position
+        if start_from_position is not None:
+            self._position = start_from_position

        assert not self._closed
        if torch.is_grad_enabled():
@@ -318,7 +318,7 @@ def step(
                server_session = self._server_sessions[server_idx]
                inputs = server_session.step(
                    inputs, prompts[server_session.span.start : server_session.span.end], hypo_ids,
-                    step_id=step_id, last_validated_position=last_validated_position
+                    step_id=step_id, start_from_position=start_from_position
                )

                server_idx += 1
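At the client level, the renamed keyword lets a caller roll a running session back and resend only the positions that changed. A hedged usage sketch, assuming `session` is a client `InferenceSession` and `hidden_states` holds already-embedded inputs of shape `[batch, seq_len, hidden_dim]` (neither assumption comes from this diff):

```python
import torch

def resend_from(session, hidden_states: torch.Tensor, keep: int) -> torch.Tensor:
    """Rewind `session` to `keep` processed positions and resend the tail.

    Hypothetical helper: `session` is assumed to have already processed all
    positions in `hidden_states` before the rewind.
    """
    tail = hidden_states[:, keep:, :]
    # start_from_position resets the session position and truncates the cached
    # history to `keep` positions before the new chunk is processed
    # (see the hunks above).
    return session.step(tail, start_from_position=keep)
```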