Skip to content

Commit 8b357cd

Browse files
committed
use the base Protocol
1 parent 28afc38 commit 8b357cd

File tree

1 file changed

+30
-67
lines changed

1 file changed

+30
-67
lines changed

pymongo/network_layer.py

Lines changed: 30 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,7 @@
2222
import struct
2323
import sys
2424
import time
25-
from asyncio import BaseTransport, BufferedProtocol, Future, Transport
26-
from dataclasses import dataclass
25+
from asyncio import BaseTransport, BufferedProtocol, Future, Protocol, Transport
2726
from typing import (
2827
TYPE_CHECKING,
2928
Any,
@@ -251,7 +250,7 @@ def recv_into(self, buffer: bytes) -> int:
251250
return self.conn.recv_into(buffer)
252251

253252

254-
class PyMongoBaseProtocol(BufferedProtocol):
253+
class PyMongoBaseProtocol(Protocol):
255254
def __init__(self, timeout: Optional[float] = None):
256255
self.transport: Transport = None # type: ignore[assignment]
257256
self._timeout = timeout
@@ -293,7 +292,7 @@ async def read(self, *args: Any) -> Any:
293292
raise NotImplementedError
294293

295294

296-
class PyMongoProtocol(PyMongoBaseProtocol):
295+
class PyMongoProtocol(PyMongoBaseProtocol, BufferedProtocol):
297296
def __init__(self, timeout: Optional[float] = None):
298297
super().__init__(timeout)
299298
# Each message is reader in 2-3 parts: header, compression header, and message body
@@ -477,17 +476,10 @@ def _resolve_pending(self, exc: Optional[Exception] = None) -> None:
477476
self._done_messages.append(msg)
478477

479478

480-
@dataclass
481-
class KMSBuffer:
482-
buffer: memoryview
483-
start_index: int
484-
end_index: int
485-
486-
487479
class PyMongoKMSProtocol(PyMongoBaseProtocol):
488480
def __init__(self, timeout: Optional[float] = None):
489481
super().__init__(timeout)
490-
self._buffers: collections.deque[KMSBuffer] = collections.deque()
482+
self._buffers: collections.deque[memoryview[bytes]] = collections.deque()
491483
self._bytes_ready = 0
492484
self._pending_reads: collections.deque[int] = collections.deque()
493485
self._pending_listeners: collections.deque[Future[Any]] = collections.deque()
@@ -498,6 +490,24 @@ def connection_made(self, transport: BaseTransport) -> None:
498490
"""
499491
self.transport = transport # type: ignore[assignment]
500492

493+
def data_received(self, data: bytes) -> None:
494+
if self._connection_lost:
495+
return
496+
497+
self._bytes_ready += len(data)
498+
self._buffers.append(memoryview[data])
499+
500+
if not len(self._pending_reads):
501+
return
502+
503+
bytes_needed = self._pending_reads.popleft()
504+
data = self._read(bytes_needed)
505+
waiter = self._pending_listeners.popleft()
506+
waiter.set_result(data)
507+
508+
def eof_received(self):
509+
self.close(OSError("connection closed"))
510+
501511
async def read(self, bytes_needed: int) -> bytes:
502512
"""Read up to the requested bytes from this connection."""
503513
# Note: all reads are "up-to" bytes_needed because we don't know if the kms_context
@@ -521,51 +531,13 @@ async def read(self, bytes_needed: int) -> bytes:
521531
self._pending_listeners.append(read_waiter)
522532
return await read_waiter
523533

524-
def get_buffer(self, sizehint: int) -> memoryview:
525-
"""Called to allocate a new receive buffer.
526-
The asyncio loop calls this method expecting to receive a non-empty buffer to fill with data.
527-
If any data does not fit into the returned buffer, this method will be called again until
528-
either no data remains or an empty buffer is returned.
529-
"""
530-
# Reuse the active buffer if it has space.
531-
# Allocate a bit more than the max response size for an AWS KMS response.
532-
sizehint = max(sizehint, 16384)
533-
if len(self._buffers):
534-
buffer = self._buffers[-1]
535-
if len(buffer.buffer) - buffer.end_index > sizehint:
536-
return buffer.buffer[buffer.end_index :]
537-
buffer = KMSBuffer(memoryview(bytearray(sizehint)), 0, 0)
538-
self._buffers.append(buffer)
539-
return buffer.buffer
540-
541534
def _resolve_pending(self, exc: Optional[Exception] = None) -> None:
542535
while self._pending_listeners:
543536
fut = self._pending_listeners.popleft()
544537
fut.set_result(b"")
545538

546-
def buffer_updated(self, nbytes: int) -> None:
547-
"""Called when the buffer was updated with the received data"""
548-
# Wrote 0 bytes into a non-empty buffer, signal connection closed
549-
if nbytes == 0:
550-
self.close(OSError("connection closed"))
551-
return
552-
if self._connection_lost:
553-
return
554-
self._bytes_ready += nbytes
555-
556-
# Update the length of the current buffer.
557-
self._buffers[-1].end_index += nbytes
558-
559-
if not len(self._pending_reads):
560-
return
561-
562-
bytes_needed = self._pending_reads.popleft()
563-
data = self._read(bytes_needed)
564-
waiter = self._pending_listeners.popleft()
565-
waiter.set_result(data)
566-
567539
def _read(self, bytes_needed: int) -> memoryview:
568-
"""Read bytes from the buffer."""
540+
"""Read bytes."""
569541
# Send the bytes to the listener.
570542
if self._bytes_ready < bytes_needed:
571543
bytes_needed = self._bytes_ready
@@ -576,26 +548,17 @@ def _read(self, bytes_needed: int) -> memoryview:
576548
out_index = 0
577549
while n_remaining > 0:
578550
buffer = self._buffers.popleft()
579-
buffer_remaining = buffer.end_index - buffer.start_index
551+
buf_size = len(buffer)
580552
# if we didn't exhaust the buffer, read the partial data and return the buffer.
581-
if buffer_remaining > n_remaining:
582-
output_buf[out_index : n_remaining + out_index] = buffer.buffer[
583-
buffer.start_index : buffer.start_index + n_remaining
584-
]
585-
buffer.start_index += n_remaining
553+
if buf_size > n_remaining:
554+
output_buf[out_index : n_remaining + out_index] = buffer[:n_remaining]
586555
n_remaining = 0
587-
self._buffers.appendleft(buffer)
556+
self._buffers.appendleft(buffer[n_remaining:])
588557
# otherwise exhaust the buffer.
589558
else:
590-
output_buf[out_index : out_index + buffer_remaining] = buffer.buffer[
591-
buffer.start_index : buffer.end_index
592-
]
593-
out_index += buffer_remaining
594-
n_remaining -= buffer_remaining
595-
# if this is the only buffer, add it back to the queue.
596-
if not len(self._buffers):
597-
buffer.start_index = buffer.end_index
598-
self._buffers.appendleft(buffer)
559+
output_buf[out_index : out_index + buf_size] = buffer[:]
560+
out_index += buf_size
561+
n_remaining -= buf_size
599562
return memoryview(output_buf)
600563

601564

0 commit comments

Comments
 (0)