22
22
import struct
23
23
import sys
24
24
import time
25
- from asyncio import BaseTransport , BufferedProtocol , Future , Transport
26
- from dataclasses import dataclass
25
+ from asyncio import BaseTransport , BufferedProtocol , Future , Protocol , Transport
27
26
from typing import (
28
27
TYPE_CHECKING ,
29
28
Any ,
@@ -251,7 +250,7 @@ def recv_into(self, buffer: bytes) -> int:
251
250
return self .conn .recv_into (buffer )
252
251
253
252
254
- class PyMongoBaseProtocol (BufferedProtocol ):
253
+ class PyMongoBaseProtocol (Protocol ):
255
254
def __init__ (self , timeout : Optional [float ] = None ):
256
255
self .transport : Transport = None # type: ignore[assignment]
257
256
self ._timeout = timeout
@@ -293,7 +292,7 @@ async def read(self, *args: Any) -> Any:
293
292
raise NotImplementedError
294
293
295
294
296
- class PyMongoProtocol (PyMongoBaseProtocol ):
295
+ class PyMongoProtocol (PyMongoBaseProtocol , BufferedProtocol ):
297
296
def __init__ (self , timeout : Optional [float ] = None ):
298
297
super ().__init__ (timeout )
299
298
# Each message is reader in 2-3 parts: header, compression header, and message body
@@ -477,17 +476,10 @@ def _resolve_pending(self, exc: Optional[Exception] = None) -> None:
477
476
self ._done_messages .append (msg )
478
477
479
478
480
- @dataclass
481
- class KMSBuffer :
482
- buffer : memoryview
483
- start_index : int
484
- end_index : int
485
-
486
-
487
479
class PyMongoKMSProtocol (PyMongoBaseProtocol ):
488
480
def __init__ (self , timeout : Optional [float ] = None ):
489
481
super ().__init__ (timeout )
490
- self ._buffers : collections .deque [KMSBuffer ] = collections .deque ()
482
+ self ._buffers : collections .deque [memoryview [ bytes ] ] = collections .deque ()
491
483
self ._bytes_ready = 0
492
484
self ._pending_reads : collections .deque [int ] = collections .deque ()
493
485
self ._pending_listeners : collections .deque [Future [Any ]] = collections .deque ()
@@ -498,6 +490,24 @@ def connection_made(self, transport: BaseTransport) -> None:
498
490
"""
499
491
self .transport = transport # type: ignore[assignment]
500
492
493
+ def data_received (self , data : bytes ) -> None :
494
+ if self ._connection_lost :
495
+ return
496
+
497
+ self ._bytes_ready += len (data )
498
+ self ._buffers .append (memoryview [data ])
499
+
500
+ if not len (self ._pending_reads ):
501
+ return
502
+
503
+ bytes_needed = self ._pending_reads .popleft ()
504
+ data = self ._read (bytes_needed )
505
+ waiter = self ._pending_listeners .popleft ()
506
+ waiter .set_result (data )
507
+
508
+ def eof_received (self ):
509
+ self .close (OSError ("connection closed" ))
510
+
501
511
async def read (self , bytes_needed : int ) -> bytes :
502
512
"""Read up to the requested bytes from this connection."""
503
513
# Note: all reads are "up-to" bytes_needed because we don't know if the kms_context
@@ -521,51 +531,13 @@ async def read(self, bytes_needed: int) -> bytes:
521
531
self ._pending_listeners .append (read_waiter )
522
532
return await read_waiter
523
533
524
- def get_buffer (self , sizehint : int ) -> memoryview :
525
- """Called to allocate a new receive buffer.
526
- The asyncio loop calls this method expecting to receive a non-empty buffer to fill with data.
527
- If any data does not fit into the returned buffer, this method will be called again until
528
- either no data remains or an empty buffer is returned.
529
- """
530
- # Reuse the active buffer if it has space.
531
- # Allocate a bit more than the max response size for an AWS KMS response.
532
- sizehint = max (sizehint , 16384 )
533
- if len (self ._buffers ):
534
- buffer = self ._buffers [- 1 ]
535
- if len (buffer .buffer ) - buffer .end_index > sizehint :
536
- return buffer .buffer [buffer .end_index :]
537
- buffer = KMSBuffer (memoryview (bytearray (sizehint )), 0 , 0 )
538
- self ._buffers .append (buffer )
539
- return buffer .buffer
540
-
541
534
def _resolve_pending (self , exc : Optional [Exception ] = None ) -> None :
542
535
while self ._pending_listeners :
543
536
fut = self ._pending_listeners .popleft ()
544
537
fut .set_result (b"" )
545
538
546
- def buffer_updated (self , nbytes : int ) -> None :
547
- """Called when the buffer was updated with the received data"""
548
- # Wrote 0 bytes into a non-empty buffer, signal connection closed
549
- if nbytes == 0 :
550
- self .close (OSError ("connection closed" ))
551
- return
552
- if self ._connection_lost :
553
- return
554
- self ._bytes_ready += nbytes
555
-
556
- # Update the length of the current buffer.
557
- self ._buffers [- 1 ].end_index += nbytes
558
-
559
- if not len (self ._pending_reads ):
560
- return
561
-
562
- bytes_needed = self ._pending_reads .popleft ()
563
- data = self ._read (bytes_needed )
564
- waiter = self ._pending_listeners .popleft ()
565
- waiter .set_result (data )
566
-
567
539
def _read (self , bytes_needed : int ) -> memoryview :
568
- """Read bytes from the buffer ."""
540
+ """Read bytes."""
569
541
# Send the bytes to the listener.
570
542
if self ._bytes_ready < bytes_needed :
571
543
bytes_needed = self ._bytes_ready
@@ -576,26 +548,17 @@ def _read(self, bytes_needed: int) -> memoryview:
576
548
out_index = 0
577
549
while n_remaining > 0 :
578
550
buffer = self ._buffers .popleft ()
579
- buffer_remaining = buffer . end_index - buffer . start_index
551
+ buf_size = len ( buffer )
580
552
# if we didn't exhaust the buffer, read the partial data and return the buffer.
581
- if buffer_remaining > n_remaining :
582
- output_buf [out_index : n_remaining + out_index ] = buffer .buffer [
583
- buffer .start_index : buffer .start_index + n_remaining
584
- ]
585
- buffer .start_index += n_remaining
553
+ if buf_size > n_remaining :
554
+ output_buf [out_index : n_remaining + out_index ] = buffer [:n_remaining ]
586
555
n_remaining = 0
587
- self ._buffers .appendleft (buffer )
556
+ self ._buffers .appendleft (buffer [ n_remaining :] )
588
557
# otherwise exhaust the buffer.
589
558
else :
590
- output_buf [out_index : out_index + buffer_remaining ] = buffer .buffer [
591
- buffer .start_index : buffer .end_index
592
- ]
593
- out_index += buffer_remaining
594
- n_remaining -= buffer_remaining
595
- # if this is the only buffer, add it back to the queue.
596
- if not len (self ._buffers ):
597
- buffer .start_index = buffer .end_index
598
- self ._buffers .appendleft (buffer )
559
+ output_buf [out_index : out_index + buf_size ] = buffer [:]
560
+ out_index += buf_size
561
+ n_remaining -= buf_size
599
562
return memoryview (output_buf )
600
563
601
564
0 commit comments