Skip to content

Commit 83b296f

Browse files
authored
fix(sqs): don't crash on multiple predefined queues with aws sts session (#2224)
* chore(sqs): write the test case for multiple predefined queues with aws sts session * fix(sqs): don't crash on multiple predefined queues with aws sts session * refactor(sqs): make _new_predefined_queue_client_with_sts_session()
1 parent 4c64cdd commit 83b296f

File tree

2 files changed

+46
-22
lines changed

2 files changed

+46
-22
lines changed

kombu/transport/SQS.py

Lines changed: 18 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -766,34 +766,30 @@ def sqs(self, queue=None):
766766
return c
767767

768768
def _handle_sts_session(self, queue, q):
769+
region = q.get('region', self.region)
769770
if not hasattr(self, 'sts_expiration'): # STS token - token init
770-
sts_creds = self.generate_sts_session_token(
771-
self.transport_options.get('sts_role_arn'),
772-
self.transport_options.get('sts_token_timeout', 900))
773-
self.sts_expiration = sts_creds['Expiration']
774-
c = self._predefined_queue_clients[queue] = self.new_sqs_client(
775-
region=q.get('region', self.region),
776-
access_key_id=sts_creds['AccessKeyId'],
777-
secret_access_key=sts_creds['SecretAccessKey'],
778-
session_token=sts_creds['SessionToken'],
779-
)
780-
return c
771+
return self._new_predefined_queue_client_with_sts_session(queue, region)
781772
# STS token - refresh if expired
782773
elif self.sts_expiration.replace(tzinfo=None) < datetime.utcnow():
783-
sts_creds = self.generate_sts_session_token(
784-
self.transport_options.get('sts_role_arn'),
785-
self.transport_options.get('sts_token_timeout', 900))
786-
self.sts_expiration = sts_creds['Expiration']
787-
c = self._predefined_queue_clients[queue] = self.new_sqs_client(
788-
region=q.get('region', self.region),
789-
access_key_id=sts_creds['AccessKeyId'],
790-
secret_access_key=sts_creds['SecretAccessKey'],
791-
session_token=sts_creds['SessionToken'],
792-
)
793-
return c
774+
return self._new_predefined_queue_client_with_sts_session(queue, region)
794775
else: # STS token - ruse existing
776+
if queue not in self._predefined_queue_clients:
777+
return self._new_predefined_queue_client_with_sts_session(queue, region)
795778
return self._predefined_queue_clients[queue]
796779

780+
def _new_predefined_queue_client_with_sts_session(self, queue, region):
781+
sts_creds = self.generate_sts_session_token(
782+
self.transport_options.get('sts_role_arn'),
783+
self.transport_options.get('sts_token_timeout', 900))
784+
self.sts_expiration = sts_creds['Expiration']
785+
c = self._predefined_queue_clients[queue] = self.new_sqs_client(
786+
region=region,
787+
access_key_id=sts_creds['AccessKeyId'],
788+
secret_access_key=sts_creds['SecretAccessKey'],
789+
session_token=sts_creds['SessionToken'],
790+
)
791+
return c
792+
797793
def generate_sts_session_token(self, role_arn, token_expiry_seconds):
798794
sts_client = boto3.client('sts')
799795
sts_policy = sts_client.assume_role(

t/unit/transport/test_SQS.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -996,6 +996,34 @@ def test_sts_session_not_expired(self):
996996
# Assert
997997
mock_generate_sts_session_token.assert_not_called()
998998

999+
def test_sts_session_with_multiple_predefined_queues(self):
1000+
connection = Connection(transport=SQS.Transport, transport_options={
1001+
'predefined_queues': example_predefined_queues,
1002+
'sts_role_arn': 'test::arn'
1003+
})
1004+
channel = connection.channel()
1005+
sqs = SQS_Channel_sqs.__get__(channel, SQS.Channel)
1006+
1007+
mock_generate_sts_session_token = Mock()
1008+
mock_new_sqs_client = Mock()
1009+
channel.new_sqs_client = mock_new_sqs_client
1010+
mock_generate_sts_session_token.return_value = {
1011+
'Expiration': datetime.utcnow() + timedelta(days=1),
1012+
'SessionToken': 123,
1013+
'AccessKeyId': 123,
1014+
'SecretAccessKey': 123
1015+
}
1016+
1017+
channel.generate_sts_session_token = mock_generate_sts_session_token
1018+
1019+
# Act
1020+
sqs(queue='queue-1')
1021+
sqs(queue='queue-2')
1022+
1023+
# Assert
1024+
mock_generate_sts_session_token.assert_called()
1025+
mock_new_sqs_client.assert_called()
1026+
9991027
def test_message_attribute(self):
10001028
message = 'my test message'
10011029
self.producer.publish(message, message_attributes={

0 commit comments

Comments
 (0)