
Commit 269028d

fix
1 parent 7565dde commit 269028d

File tree

3 files changed (+17 −17 lines):
src/petals/client/inference_session.py
src/petals/server/block_functions.py
tests/test_speculative_generation.py


src/petals/client/inference_session.py

Lines changed: 12 additions & 12 deletions
```diff
@@ -85,7 +85,7 @@ async def _read_inputs_from_queue(queue: asyncio.Queue, input_timeout: Optional[
 
     def step(
         self, inputs: torch.Tensor, prompts: torch.Tensor, hypo_ids: torch.LongTensor, *,
-        step_id: str, last_validated_position: int
+        step_id: str, start_from_position: int
     ) -> torch.Tensor:
         """
         Inference step: send a chunk of input tensors and receive a chunk of outputs
@@ -95,11 +95,11 @@ def step(
         if self.closed:
             raise Exception("Session is closed, cannot perform step")
 
-        if last_validated_position is not None:
-            assert last_validated_position <= self._position
-            self._position = last_validated_position
-            if self.history is not None and self.history.shape[1] >= last_validated_position:
-                self.history = self.history[:, :last_validated_position, :] if last_validated_position > 0 else None
+        if start_from_position is not None:
+            assert start_from_position <= self._position
+            self._position = start_from_position
+            if self.history is not None and self.history.shape[1] >= start_from_position:
+                self.history = self.history[:, :start_from_position, :] if start_from_position > 0 else None
 
         n_input_tokens = inputs.shape[1]
         if self.history is None:
@@ -122,8 +122,8 @@ def step(
         request_metadata = dict(session_id=self.session_id, step_id=step_id)
         if not self.stepped:
             request_metadata.update(self.session_metadata)
-            if last_validated_position is not None:
-                request_metadata["last_validated_position"] = last_validated_position
+            if start_from_position is not None:
+                request_metadata["start_from_position"] = start_from_position
         elif self.config.use_server_to_server:
             next_servers = self._collect_next_servers()
             if next_servers:
@@ -267,11 +267,11 @@ def __enter__(self) -> "InferenceSession":
 
     def step(
         self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None,
-        hypo_ids: Optional[torch.Tensor] = None, last_validated_position: Optional[int] = None
+        hypo_ids: Optional[torch.Tensor] = None, start_from_position: Optional[int] = None
     ) -> torch.Tensor:
 
-        if last_validated_position is not None:
-            self._position = last_validated_position
+        if start_from_position is not None:
+            self._position = start_from_position
 
         assert not self._closed
         if torch.is_grad_enabled():
@@ -318,7 +318,7 @@ def step(
             server_session = self._server_sessions[server_idx]
             inputs = server_session.step(
                 inputs, prompts[server_session.span.start : server_session.span.end], hypo_ids,
-                step_id=step_id, last_validated_position=last_validated_position
+                step_id=step_id, start_from_position=start_from_position
             )
 
             server_idx += 1
```
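The renamed argument drives the client-side rewind: when `start_from_position` is set, the session rolls `self._position` back and truncates its cached `history` tensor, so the next chunk of inputs recomputes from that point. Below is a minimal standalone sketch of the truncation logic, assuming a toy class and a history tensor of shape `(batch, seq_len, hidden)`; it is not the actual Petals session API:

```python
import torch

class SessionSketch:
    """Toy stand-in for the rewind bookkeeping in the first step() above."""

    def __init__(self):
        self._position = 0   # number of tokens already processed
        self.history = None  # cached inputs, shape (batch, seq_len, hidden), or None

    def rewind(self, start_from_position: int) -> None:
        # May only rewind to a point the session has already passed
        assert start_from_position <= self._position
        self._position = start_from_position
        if self.history is not None and self.history.shape[1] >= start_from_position:
            # Drop cached entries past the rewind point; rewinding to 0 clears the cache
            self.history = self.history[:, :start_from_position, :] if start_from_position > 0 else None

sess = SessionSketch()
sess.history = torch.zeros(1, 5, 8)  # pretend 5 tokens were already processed
sess._position = 5
sess.rewind(2)
assert sess._position == 2 and sess.history.shape == (1, 2, 8)
```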

src/petals/server/block_functions.py

Lines changed: 4 additions & 4 deletions
```diff
@@ -160,10 +160,10 @@ async def iterate_rpc_inference(
     point_per_piece = points / max_length if max_length > 0 else 0.0
 
     async for request, step_metadata in input_iterator:
-        if "last_validated_position" in step_metadata:
-            last_validated_position = step_metadata["last_validated_position"]
-            assert prefix_length >= last_validated_position, f"prefix_length={prefix_length}, last_validated_position={last_validated_position}"
-            prefix_length = last_validated_position
+        if "start_from_position" in step_metadata:
+            start_from_position = step_metadata["start_from_position"]
+            assert prefix_length >= start_from_position, f"prefix_length={prefix_length}, start_from_position={start_from_position}"
+            prefix_length = start_from_position
 
         flat_tensors = tuple(deserialize_torch_tensor(tensor) for tensor in request.tensors)
         if args_structure is not None:
```
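On the server side, the rewind arrives as a plain metadata key on each step rather than as a tensor, and it can only shrink `prefix_length`. A hedged sketch of that check factored into a pure function (the name `apply_rewind` is hypothetical; the real `iterate_rpc_inference` goes on to deserialize tensors and run the blocks):

```python
def apply_rewind(prefix_length: int, step_metadata: dict) -> int:
    """Roll the server's prefix length back if the client requested a rewind."""
    if "start_from_position" in step_metadata:
        start_from_position = step_metadata["start_from_position"]
        # Clients may only rewind into the prefix they have already sent
        assert prefix_length >= start_from_position, (
            f"prefix_length={prefix_length}, start_from_position={start_from_position}"
        )
        prefix_length = start_from_position
    return prefix_length

assert apply_rewind(5, {"start_from_position": 2}) == 2  # rewind accepted
assert apply_rewind(5, {}) == 5                          # no-op without the key
```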

tests/test_speculative_generation.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -26,7 +26,7 @@ def test_remote_block_with_cache_invalidation_exact_match(atol_forward=1e-4, ato
     with torch.inference_mode():
         with remote_block.inference_session(max_length=inputs.shape[1]) as sess:
             initial_outputs_inference = sess.step(inputs)
-            secondary_outputs_inference = sess.step(short_inputs[:, 2:, :], last_validated_position=2)
+            secondary_outputs_inference = sess.step(short_inputs[:, 2:, :], start_from_position=2)
         result = torch.cat([initial_outputs_inference[:, :2, :], secondary_outputs_inference], dim=1)
 
     ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32)
```
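The test exercises the rewind end to end: one full step, then a second step that re-submits tokens from position 2 onward and must reproduce the reference block's activations exactly. For speculative generation, the same call pattern lets a client keep only verified draft tokens each round; the position bookkeeping can be sketched with plain integers (a toy simulation under assumptions; the acceptance counts and function name are illustrative, not Petals code):

```python
def rewind_schedule(accepted_per_round, start=0):
    """Yield (start_from_position, position_after_round) for each verification round."""
    position = start
    for n_accepted in accepted_per_round:
        # Each round would call sess.step(draft_tokens, start_from_position=position),
        # discarding any cached states left over from previously rejected drafts
        yield position, position + n_accepted
        position += n_accepted

print(list(rewind_schedule([2, 4, 1])))  # [(0, 2), (2, 6), (6, 7)]
```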
