@@ -280,12 +280,12 @@ def close(self):
280
280
self .conn .parent .stop (self )
281
281
282
282
@staticmethod
283
- def _wait_to_timeout (wait ):
283
+ def _wait_to_timeout (wait ) -> Optional [ float ] :
284
284
if isinstance (wait , bool ):
285
- return None if wait else 0
285
+ return None if wait else 0.0
286
286
287
- if isinstance (wait , numbers .Real ) and wait >= 0 :
288
- return wait
287
+ if isinstance (wait , numbers .Real ) and float ( wait ) >= 0 :
288
+ return float ( wait )
289
289
290
290
raise ReqlDriverError (f"Invalid wait timeout '{ wait } '" )
291
291
@@ -397,7 +397,7 @@ def __next__(self):
397
397
398
398
@staticmethod
399
399
def _empty_error ():
400
- return DefaultCursorEmpty ()
400
+ return DefaultCursorEmpty
401
401
402
402
def _get_next (self , timeout : Optional [float ] = None ):
403
403
deadline = None if timeout is None else time .time () + timeout
@@ -425,44 +425,33 @@ def __init__(self, parent: "ConnectionInstance", timeout: int):
425
425
self .port : int = parent .parent .port
426
426
self .ssl : dict = parent .parent .ssl
427
427
self ._read_buffer : Optional [bytes ] = None
428
+ self .__socket : Optional [Union [socket .socket , ssl .SSLSocket ]] = None
428
429
429
430
deadline : float = time .time () + timeout
430
431
431
432
try :
432
433
self .__socket = socket .create_connection ((self .host , self .port ), timeout )
433
434
434
- self .socket .setsockopt (socket .IPPROTO_TCP , socket .TCP_NODELAY , 1 )
435
- self .socket .setsockopt (socket .SOL_SOCKET , socket .SO_KEEPALIVE , 1 )
435
+
436
+ sock = self .__socket
437
+ sock .setsockopt (socket .IPPROTO_TCP , socket .TCP_NODELAY , 1 )
438
+ sock .setsockopt (socket .SOL_SOCKET , socket .SO_KEEPALIVE , 1 )
436
439
437
440
if len (self .ssl ) > 0 :
438
441
try :
439
- if hasattr (
440
- ssl , "SSLContext"
441
- ): # Python2.7 and 3.2+, or backports.ssl
442
- ssl_context = ssl .SSLContext (ssl .PROTOCOL_SSLv23 )
443
- if hasattr (ssl_context , "options" ):
444
- ssl_context .options |= getattr (ssl , "OP_NO_SSLv2" , 0 )
445
- ssl_context .options |= getattr (ssl , "OP_NO_SSLv3" , 0 )
446
- ssl_context .verify_mode = ssl .CERT_REQUIRED
447
- ssl_context .check_hostname = (
448
- True # redundant with ssl.match_hostname
449
- )
450
- ssl_context .load_verify_locations (self .ssl ["ca_certs" ])
442
+
443
+ ssl_context = ssl .SSLContext (ssl .PROTOCOL_TLS_CLIENT )
444
+ ssl_context .check_hostname = True
445
+ ssl_context .verify_mode = ssl .CERT_REQUIRED
446
+ ssl_context .load_verify_locations (self .ssl ["ca_certs" ])
447
+
448
+ if self .__socket is not None :
451
449
self .socket = ssl_context .wrap_socket (
452
- self .socket , server_hostname = self .host
453
- )
454
- else : # this does not disable SSLv2 or SSLv3
455
- # TODO: Replace the deprecated wrap_socket
456
- self .socket = (
457
- ssl .wrap_socket ( # pylint: disable=deprecated-method
458
- self .socket ,
459
- cert_reqs = ssl .CERT_REQUIRED ,
460
- ssl_version = ssl .PROTOCOL_SSLv23 ,
461
- ca_certs = self .ssl ["ca_certs" ],
462
- )
450
+ self .__socket , server_hostname = self .host
463
451
)
464
452
except IOError as err :
465
- self .socket .close ()
453
+ if self .__socket is not None :
454
+ self .__socket .close ()
466
455
467
456
if "EOF occurred in violation of protocol" in str (
468
457
err
@@ -482,14 +471,6 @@ def __init__(self, parent: "ConnectionInstance", timeout: int):
482
471
raise ReqlDriverError (
483
472
f"SSL handshake failed (see server log for more information): { err } "
484
473
) from err
485
- try :
486
- # TODO: Replace the deprecated match_hostname
487
- ssl .match_hostname ( # pylint: disable=deprecated-method
488
- self .socket .getpeercert (), hostname = self .host
489
- )
490
- except ssl .CertificateError :
491
- self .socket .close ()
492
- raise
493
474
494
475
parent .parent .handshake .reset ()
495
476
response = None
@@ -530,24 +511,24 @@ def __init__(self, parent: "ConnectionInstance", timeout: int):
530
511
) from exc
531
512
532
513
@property
533
- def socket (self ) -> Union [socket .socket , ssl .SSLSocket ]:
514
+ def socket (self ) -> "Optional[ Union[socket.socket, ssl.SSLSocket]]" :
534
515
"""
535
516
Return the wrapped socket.
536
517
"""
537
518
return self .__socket
538
519
539
520
@socket .setter
540
- def socket (self , value : "socket.socket" ):
521
+ def socket (self , value : "Optional[Union[ socket.socket, ssl.SSLSocket]] " ):
541
522
"""
542
523
Set the socket instance.
543
524
"""
544
- self ._socket = value
525
+ self .__socket = value
545
526
546
527
def is_open (self ):
547
528
"""
548
529
Return if the connection is open.
549
530
"""
550
- return self .socket is not None
531
+ return self .__socket is not None
551
532
552
533
def close (self ):
553
534
"""
@@ -557,8 +538,9 @@ def close(self):
557
538
return
558
539
559
540
try :
560
- self .socket .shutdown (socket .SHUT_RDWR )
561
- self .socket .close ()
541
+ if self .__socket is not None :
542
+ self .__socket .shutdown (socket .SHUT_RDWR )
543
+ self .__socket .close ()
562
544
except ReqlError as exc :
563
545
logger .error (exc .message )
564
546
except Exception as exc : # pylint: disable=broad-except
@@ -574,16 +556,20 @@ def recvall(self, length: int, deadline: Optional[float]):
574
556
timeout : Optional [float ] = (
575
557
None if deadline is None else max (0.0 , deadline - time .time ())
576
558
)
577
- self .socket .settimeout (timeout )
559
+ if self .__socket is not None :
560
+ self .__socket .settimeout (timeout )
578
561
while len (res ) < length :
579
562
while True :
580
563
try :
581
- chunk = self .socket .recv (length - len (res ))
582
- self .socket .settimeout (None )
564
+ if self .__socket is None :
565
+ raise ReqlDriverError ("Socket is None" )
566
+ chunk = self .__socket .recv (length - len (res ))
567
+ self .__socket .settimeout (None )
583
568
break
584
569
except socket .timeout as exc :
585
570
self ._read_buffer = res
586
- self .socket .settimeout (None )
571
+ if self .__socket is not None :
572
+ self .__socket .settimeout (None )
587
573
raise ReqlTimeoutError (self .host , self .port ) from exc
588
574
except IOError as exc :
589
575
if exc .errno == errno .ECONNRESET :
@@ -618,7 +604,9 @@ def sendall(self, data: bytes):
618
604
offset = 0
619
605
while offset < len (data ):
620
606
try :
621
- offset += self .socket .send (data [offset :])
607
+ if self .__socket is None :
608
+ raise ReqlDriverError ("Socket is None" )
609
+ offset += self .__socket .send (data [offset :])
622
610
except IOError as exc :
623
611
if exc .errno == errno .ECONNRESET :
624
612
self .close ()
@@ -683,7 +671,12 @@ def client_port(self) -> Optional[int]:
683
671
if not self .is_open ():
684
672
return None
685
673
686
- return self .socket .socket .getsockname ()[1 ]
674
+ if self .socket is None :
675
+ raise ReqlDriverError ("Socket unexpectedly returned none." )
676
+ socket_obj = self .socket .socket
677
+ if socket_obj is not None :
678
+ return socket_obj .getsockname ()[1 ]
679
+ return None
687
680
688
681
def client_address (self ) -> Optional [str ]:
689
682
"""
@@ -695,7 +688,12 @@ def client_address(self) -> Optional[str]:
695
688
if not self .is_open ():
696
689
return None
697
690
698
- return self .socket .socket .getsockname ()[0 ]
691
+ if self .socket is None :
692
+ raise ReqlDriverError ("Socket unexpectedly returned none." )
693
+ socket_obj = self .socket .socket
694
+ if socket_obj is not None :
695
+ return socket_obj .getsockname ()[0 ]
696
+ return None
699
697
700
698
def connect (self , timeout : int ) -> "Connection" :
701
699
"""
@@ -728,7 +726,12 @@ def close(self, noreply_wait=False, token=None) -> None:
728
726
729
727
try :
730
728
if noreply_wait :
731
- query = Query (PbQuery .QueryType .NOREPLY_WAIT , token , None , None )
729
+ query = Query (
730
+ PbQuery .QueryType .NOREPLY_WAIT ,
731
+ token or self .parent ._new_token (), # pylint: disable=protected-access
732
+ None ,
733
+ None
734
+ )
732
735
self .run_query (query , False )
733
736
finally :
734
737
if self .socket is None :
@@ -806,7 +809,7 @@ def read_response(self, query, deadline=None) -> Optional[Response]:
806
809
807
810
res = None
808
811
809
- cursor : Cursor = self .cursor_cache .get (res_token )
812
+ cursor : Optional [ Cursor ] = self .cursor_cache .get (res_token )
810
813
if cursor is not None :
811
814
# Construct response
812
815
cursor .extend (res_buf )
@@ -843,7 +846,7 @@ class Connection: # pylint: disable=too-many-instance-attributes
843
846
_json_decoder = ReqlDecoder
844
847
_json_encoder = ReqlEncoder
845
848
846
- # pylint: disable=too-many-arguments
849
+ # pylint: disable=too-many-arguments,too-many-positional-arguments
847
850
def __init__ ( # nosec
848
851
self ,
849
852
conn_type ,
@@ -853,7 +856,7 @@ def __init__( # nosec
853
856
user : str ,
854
857
password : str = "" ,
855
858
timeout : int = 0 ,
856
- ssl : dict = None , # pylint: disable=redefined-outer-name
859
+ ssl : Optional [ dict ] = None , # pylint: disable=redefined-outer-name
857
860
_handshake_version : Type [BaseHandshake ] = HandshakeV1_0 ,
858
861
** kwargs ,
859
862
):
@@ -988,6 +991,8 @@ def noreply_wait(self):
988
991
"""
989
992
self .check_open ()
990
993
query = Query (PbQuery .QueryType .NOREPLY_WAIT , self ._new_token (), None , None )
994
+ if self ._instance is None :
995
+ raise ReqlDriverError ("Connection instance unexpectedly none." )
991
996
return self ._instance .run_query (query , False )
992
997
993
998
def server (self ):
@@ -997,6 +1002,8 @@ def server(self):
997
1002
998
1003
self .check_open ()
999
1004
query = Query (PbQuery .QueryType .SERVER_INFO , self ._new_token (), None , None )
1005
+ if self ._instance is None :
1006
+ raise ReqlDriverError ("Connection instance unexpectedly none." )
1000
1007
return self ._instance .run_query (query , False )
1001
1008
1002
1009
def _new_token (self ):
@@ -1012,6 +1019,8 @@ def start(self, term, **kwargs):
1012
1019
if "db" in kwargs or self .db is not None :
1013
1020
kwargs ["db" ] = DB (kwargs .get ("db" , self .db ))
1014
1021
query = Query (PbQuery .QueryType .START , self ._new_token (), term , kwargs )
1022
+ if self ._instance is None :
1023
+ raise ReqlDriverError ("Connection instance unexpectedly none." )
1015
1024
return self ._instance .run_query (query , kwargs .get ("noreply" , False ))
1016
1025
1017
1026
def resume (self , cursor ):
@@ -1020,6 +1029,8 @@ def resume(self, cursor):
1020
1029
"""
1021
1030
self .check_open ()
1022
1031
query = Query (PbQuery .QueryType .CONTINUE , cursor .query .token , None , None )
1032
+ if self ._instance is None :
1033
+ raise ReqlDriverError ("Connection instance unexpectedly none." )
1023
1034
return self ._instance .run_query (query , True )
1024
1035
1025
1036
def stop (self , cursor ):
@@ -1028,6 +1039,8 @@ def stop(self, cursor):
1028
1039
"""
1029
1040
self .check_open ()
1030
1041
query = Query (PbQuery .QueryType .STOP , cursor .query .token , None , None )
1042
+ if self ._instance is None :
1043
+ raise ReqlDriverError ("Connection instance unexpectedly none." )
1031
1044
return self ._instance .run_query (query , True )
1032
1045
1033
1046
def get_json_decoder (self , query ):
@@ -1053,6 +1066,7 @@ def __init__(self, *args, **kwargs):
1053
1066
1054
1067
1055
1068
# pylint: disable=too-many-arguments
1069
+ # pylint: disable=too-many-positional-arguments
1056
1070
def make_connection (
1057
1071
connection_type ,
1058
1072
host = DEFAULT_HOST ,
@@ -1086,10 +1100,10 @@ def make_connection(
1086
1100
port = connection_string .port or port
1087
1101
1088
1102
db = connection_string .path .replace ("/" , "" ) or None
1089
- timeout = query_string .get ("timeout" , DEFAULT_TIMEOUT )
1103
+ timeout_list = query_string .get ("timeout" , [ DEFAULT_TIMEOUT ] )
1090
1104
1091
- if timeout :
1092
- timeout = int (timeout [0 ])
1105
+ if timeout_list :
1106
+ timeout = int (timeout_list [0 ])
1093
1107
1094
1108
conn = connection_type (
1095
1109
host ,
0 commit comments