37
37
from select import select
38
38
from socket import create_connection , SHUT_RDWR , error as SocketError
39
39
from struct import pack as struct_pack , unpack as struct_unpack , unpack_from as struct_unpack_from
40
- from threading import Lock
40
+ from threading import RLock
41
41
42
42
from .constants import DEFAULT_USER_AGENT , KNOWN_HOSTS , MAGIC_PREAMBLE , TRUST_DEFAULT , TRUST_ON_FIRST_USE
43
43
from .exceptions import ProtocolError , Unauthorized , ServiceUnavailable
@@ -378,15 +378,26 @@ class ConnectionPool(object):
378
378
""" A collection of connections to one or more server addresses.
379
379
"""
380
380
381
+ closed = False
382
+
381
383
def __init__ (self , connector ):
382
384
self .connector = connector
383
385
self .connections = {}
384
- self .lock = Lock ()
386
+ self .lock = RLock ()
387
+
388
+ def __enter__ (self ):
389
+ return self
390
+
391
+ def __exit__ (self , exc_type , exc_value , traceback ):
392
+ self .close ()
385
393
386
394
def acquire (self , address ):
387
395
""" Acquire a connection to a given address from the pool.
388
396
This method is thread safe.
389
397
"""
398
+ if self .closed :
399
+ raise ServiceUnavailable ("This connection pool is closed so no new "
400
+ "connections may be acquired" )
390
401
with self .lock :
391
402
try :
392
403
connections = self .connections [address ]
@@ -411,18 +422,25 @@ def release(self, connection):
411
422
with self .lock :
412
423
connection .in_use = False
413
424
425
+ def remove (self , address ):
426
+ """ Remove an address from the connection pool, if present, closing
427
+ all connections to that address.
428
+ """
429
+ with self .lock :
430
+ for connection in self .connections .pop (address , ()):
431
+ try :
432
+ connection .close ()
433
+ except IOError :
434
+ pass
435
+
414
436
def close (self ):
415
437
""" Close all connections and empty the pool.
416
438
This method is thread safe.
417
439
"""
418
440
with self .lock :
419
- for _ , connections in self .connections .items ():
420
- for connection in connections :
421
- try :
422
- connection .close ()
423
- except IOError :
424
- pass
425
- self .connections .clear ()
441
+ self .closed = True
442
+ for address in list (self .connections ):
443
+ self .remove (address )
426
444
427
445
428
446
class CertificateStore (object ):
0 commit comments