@@ -766,34 +766,30 @@ def sqs(self, queue=None):
766
766
return c
767
767
768
768
def _handle_sts_session (self , queue , q ):
769
+ region = q .get ('region' , self .region )
769
770
if not hasattr (self , 'sts_expiration' ): # STS token - token init
770
- sts_creds = self .generate_sts_session_token (
771
- self .transport_options .get ('sts_role_arn' ),
772
- self .transport_options .get ('sts_token_timeout' , 900 ))
773
- self .sts_expiration = sts_creds ['Expiration' ]
774
- c = self ._predefined_queue_clients [queue ] = self .new_sqs_client (
775
- region = q .get ('region' , self .region ),
776
- access_key_id = sts_creds ['AccessKeyId' ],
777
- secret_access_key = sts_creds ['SecretAccessKey' ],
778
- session_token = sts_creds ['SessionToken' ],
779
- )
780
- return c
771
+ return self ._new_predefined_queue_client_with_sts_session (queue , region )
781
772
# STS token - refresh if expired
782
773
elif self .sts_expiration .replace (tzinfo = None ) < datetime .utcnow ():
783
- sts_creds = self .generate_sts_session_token (
784
- self .transport_options .get ('sts_role_arn' ),
785
- self .transport_options .get ('sts_token_timeout' , 900 ))
786
- self .sts_expiration = sts_creds ['Expiration' ]
787
- c = self ._predefined_queue_clients [queue ] = self .new_sqs_client (
788
- region = q .get ('region' , self .region ),
789
- access_key_id = sts_creds ['AccessKeyId' ],
790
- secret_access_key = sts_creds ['SecretAccessKey' ],
791
- session_token = sts_creds ['SessionToken' ],
792
- )
793
- return c
774
+ return self ._new_predefined_queue_client_with_sts_session (queue , region )
794
775
else : # STS token - ruse existing
776
+ if queue not in self ._predefined_queue_clients :
777
+ return self ._new_predefined_queue_client_with_sts_session (queue , region )
795
778
return self ._predefined_queue_clients [queue ]
796
779
780
+ def _new_predefined_queue_client_with_sts_session (self , queue , region ):
781
+ sts_creds = self .generate_sts_session_token (
782
+ self .transport_options .get ('sts_role_arn' ),
783
+ self .transport_options .get ('sts_token_timeout' , 900 ))
784
+ self .sts_expiration = sts_creds ['Expiration' ]
785
+ c = self ._predefined_queue_clients [queue ] = self .new_sqs_client (
786
+ region = region ,
787
+ access_key_id = sts_creds ['AccessKeyId' ],
788
+ secret_access_key = sts_creds ['SecretAccessKey' ],
789
+ session_token = sts_creds ['SessionToken' ],
790
+ )
791
+ return c
792
+
797
793
def generate_sts_session_token (self , role_arn , token_expiry_seconds ):
798
794
sts_client = boto3 .client ('sts' )
799
795
sts_policy = sts_client .assume_role (
0 commit comments