Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 21 additions & 25 deletions pkg/rabbitmqamqp/amqp_connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@ import (
"crypto/tls"
"crypto/x509"
"fmt"
"os"
"time"

"github.com/Azure/go-amqp"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"os"
"sync"
"time"
)

var _ = Describe("AMQP connection Test", func() {
Expand Down Expand Up @@ -117,12 +117,10 @@ var _ = Describe("AMQP connection Test", func() {
})

Describe("AMQP TLS connection should succeed with in different vhosts with Anonymous and External.", func() {
wg := &sync.WaitGroup{}
wg.Add(4)
DescribeTable("TLS connection should success in different vhosts ", func(virtualHost string, sasl amqp.SASLType) {
// Load CA cert
caCert, err := os.ReadFile("../../.ci/certs/ca_certificate.pem")
Expect(err).To(BeNil())
Expect(err).ToNot(HaveOccurred())

// Create a CA certificate pool and add the CA certificate to it
caCertPool := x509.NewCertPool()
Expand All @@ -131,7 +129,7 @@ var _ = Describe("AMQP connection Test", func() {
// Load client cert
clientCert, err := tls.LoadX509KeyPair("../../.ci/certs/client_localhost_certificate.pem",
"../../.ci/certs/client_localhost_key.pem")
Expect(err).To(BeNil())
Expect(err).ToNot(HaveOccurred())

// Create a TLS configuration
tlsConfig := &tls.Config{
Expand All @@ -146,34 +144,32 @@ var _ = Describe("AMQP connection Test", func() {
SASLType: sasl,
TLSConfig: tlsConfig,
})
Expect(err).To(BeNil())
Expect(err).ToNot(HaveOccurred())
Expect(connection).NotTo(BeNil())

// Close the connection
err = connection.Close(context.Background())
Expect(err).To(BeNil())
wg.Done()
Expect(err).ToNot(HaveOccurred())
},
Entry("with virtual host. External", "%2F", amqp.SASLTypeExternal("")),
Entry("with a not default virtual host. External", "tls", amqp.SASLTypeExternal("")),
Entry("with virtual host. Anonymous", "%2F", amqp.SASLTypeAnonymous()),
Entry("with a not default virtual host. Anonymous", "tls", amqp.SASLTypeAnonymous()),
Entry("default virtual host + External", "%2F", amqp.SASLTypeExternal("")),
Entry("non-default virtual host + External", "tls", amqp.SASLTypeExternal("")),
Entry("default virtual host + Anonymous", "%2F", amqp.SASLTypeAnonymous()),
Entry("non-default virtual host + Anonymous", "tls", amqp.SASLTypeAnonymous()),
)
go func() {
wg.Wait()
}()
})

Describe("AMQP TLS connection should fail with error.", func() {
tlsConfig := &tls.Config{}
Describe("AMQP TLS connection", func() {
It("should fail with error", func() {
tlsConfig := &tls.Config{}

// Dial the AMQP server with TLS configuration
connection, err := Dial(context.Background(), "amqps://does_not_exist:5671", &AmqpConnOptions{
TLSConfig: tlsConfig,
// Dial the AMQP server with TLS configuration
connection, err := Dial(context.Background(), "amqps://does_not_exist:5671", &AmqpConnOptions{
TLSConfig: tlsConfig,
})
Expect(connection).To(BeNil())
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("failed to open TLS connection"))
})
Expect(connection).To(BeNil())
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("failed to open TLS connection"))
})

})
178 changes: 92 additions & 86 deletions pkg/rabbitmqamqp/amqp_consumer_stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package rabbitmqamqp
import (
"context"
"fmt"
"sync"
"time"

"github.com/Azure/go-amqp"
Expand Down Expand Up @@ -320,33 +319,43 @@ var _ = Describe("Consumer stream test", func() {
})

Describe("consumer should filter messages based on application properties", func() {
qName := generateName("consumer should filter messages based on application properties")
connection, err := Dial(context.Background(), "amqp://", nil)
Expect(err).To(BeNil())
queueInfo, err := connection.Management().DeclareQueue(context.Background(), &StreamQueueSpecification{
Name: qName,
})
Expect(err).To(BeNil())
Expect(queueInfo).NotTo(BeNil())
var (
qName string
connection *AmqpConnection
)
BeforeEach(func() {
qName = generateName("consumer should filter messages based on application properties")
var err error
connection, err = Dial(context.Background(), "amqp://", nil)
Expect(err).ToNot(HaveOccurred())
queueInfo, err := connection.Management().DeclareQueue(context.Background(), &StreamQueueSpecification{
Name: qName,
})
Expect(err).ToNot(HaveOccurred())
Expect(queueInfo).NotTo(BeNil())

publishMessagesWithMessageLogic(qName, "ignoredKey", 7, func(msg *amqp.Message) {
msg.ApplicationProperties = map[string]interface{}{"ignoredKey": "ignoredValue"}
})
publishMessagesWithMessageLogic(qName, "ignoredKey", 7, func(msg *amqp.Message) {
msg.ApplicationProperties = map[string]interface{}{"ignoredKey": "ignoredValue"}
})

publishMessagesWithMessageLogic(qName, "key1", 10, func(msg *amqp.Message) {
msg.ApplicationProperties = map[string]interface{}{"key1": "value1", "constFilterKey": "constFilterValue"}
})
publishMessagesWithMessageLogic(qName, "key1", 10, func(msg *amqp.Message) {
msg.ApplicationProperties = map[string]interface{}{"key1": "value1", "constFilterKey": "constFilterValue"}
})

publishMessagesWithMessageLogic(qName, "key2", 10, func(msg *amqp.Message) {
msg.ApplicationProperties = map[string]interface{}{"key2": "value2", "constFilterKey": "constFilterValue"}
})

publishMessagesWithMessageLogic(qName, "key2", 10, func(msg *amqp.Message) {
msg.ApplicationProperties = map[string]interface{}{"key2": "value2", "constFilterKey": "constFilterValue"}
publishMessagesWithMessageLogic(qName, "key3", 10, func(msg *amqp.Message) {
msg.ApplicationProperties = map[string]interface{}{"key3": "value3", "constFilterKey": "constFilterValue"}
})
})

publishMessagesWithMessageLogic(qName, "key3", 10, func(msg *amqp.Message) {
msg.ApplicationProperties = map[string]interface{}{"key3": "value3", "constFilterKey": "constFilterValue"}
AfterEach(func() {
Expect(connection.Management().DeleteQueue(context.Background(), qName)).To(Succeed())
Expect(connection.Close(context.Background())).To(Succeed())
})

var wg sync.WaitGroup
wg.Add(3)
DescribeTable("consumer should filter messages based on application properties", func(key string, value any, label string) {

consumer, err := connection.NewConsumer(context.Background(), qName, &StreamConsumerOptions{
Expand Down Expand Up @@ -375,93 +384,96 @@ var _ = Describe("Consumer stream test", func() {
Expect(dc.Accept(context.Background())).To(BeNil())
}
Expect(consumer.Close(context.Background())).To(BeNil())
wg.Done()
},
Entry("key1 value1", "key1", "value1", "key1"),
Entry("key2 value2", "key2", "value2", "key2"),
Entry("key3 value3", "key3", "value3", "key3"),
)
go func() {
wg.Wait()
Expect(connection.Management().DeleteQueue(context.Background(), qName)).To(BeNil())
Expect(connection.Close(context.Background())).To(BeNil())
}()

})

Describe("consumer should filter messages based on properties", func() {
/*
Test the consumer should filter messages based on properties
*/
// TODO: defer cleanup to delete the stream queue
qName := generateName("consumer should filter messages based on properties")
qName += time.Now().String()
connection, err := Dial(context.Background(), "amqp://", nil)
Expect(err).To(BeNil())
queueInfo, err := connection.Management().DeclareQueue(context.Background(), &StreamQueueSpecification{
Name: qName,
})
Expect(err).To(BeNil())
Expect(queueInfo).NotTo(BeNil())
var (
qName string
connection *AmqpConnection
)

publishMessagesWithMessageLogic(qName, "MessageID", 10, func(msg *amqp.Message) {
msg.Properties = &amqp.MessageProperties{MessageID: "MessageID"}
})
BeforeEach(func() {
qName = generateName("consumer should filter messages based on properties")
qName += time.Now().String()
var err error
connection, err = Dial(context.Background(), "amqp://", nil)
Expect(err).ToNot(HaveOccurred())
queueInfo, err := connection.Management().DeclareQueue(context.Background(), &StreamQueueSpecification{
Name: qName,
})
Expect(err).ToNot(HaveOccurred())
Expect(queueInfo).NotTo(BeNil())

publishMessagesWithMessageLogic(qName, "Subject", 10, func(msg *amqp.Message) {
msg.Properties = &amqp.MessageProperties{Subject: ptr("Subject")}
})
publishMessagesWithMessageLogic(qName, "MessageID", 10, func(msg *amqp.Message) {
msg.Properties = &amqp.MessageProperties{MessageID: "MessageID"}
})

publishMessagesWithMessageLogic(qName, "ReplyTo", 10, func(msg *amqp.Message) {
msg.Properties = &amqp.MessageProperties{ReplyTo: ptr("ReplyTo")}
})
publishMessagesWithMessageLogic(qName, "Subject", 10, func(msg *amqp.Message) {
msg.Properties = &amqp.MessageProperties{Subject: ptr("Subject")}
})

publishMessagesWithMessageLogic(qName, "ContentType", 10, func(msg *amqp.Message) {
msg.Properties = &amqp.MessageProperties{ContentType: ptr("ContentType")}
})
publishMessagesWithMessageLogic(qName, "ReplyTo", 10, func(msg *amqp.Message) {
msg.Properties = &amqp.MessageProperties{ReplyTo: ptr("ReplyTo")}
})

publishMessagesWithMessageLogic(qName, "ContentEncoding", 10, func(msg *amqp.Message) {
msg.Properties = &amqp.MessageProperties{ContentEncoding: ptr("ContentEncoding")}
})
publishMessagesWithMessageLogic(qName, "ContentType", 10, func(msg *amqp.Message) {
msg.Properties = &amqp.MessageProperties{ContentType: ptr("ContentType")}
})

publishMessagesWithMessageLogic(qName, "GroupID", 10, func(msg *amqp.Message) {
msg.Properties = &amqp.MessageProperties{GroupID: ptr("GroupID")}
})
publishMessagesWithMessageLogic(qName, "ContentEncoding", 10, func(msg *amqp.Message) {
msg.Properties = &amqp.MessageProperties{ContentEncoding: ptr("ContentEncoding")}
})

publishMessagesWithMessageLogic(qName, "ReplyToGroupID", 10, func(msg *amqp.Message) {
msg.Properties = &amqp.MessageProperties{ReplyToGroupID: ptr("ReplyToGroupID")}
})
publishMessagesWithMessageLogic(qName, "GroupID", 10, func(msg *amqp.Message) {
msg.Properties = &amqp.MessageProperties{GroupID: ptr("GroupID")}
})

// GroupSequence
publishMessagesWithMessageLogic(qName, "GroupSequence", 10, func(msg *amqp.Message) {
msg.Properties = &amqp.MessageProperties{GroupSequence: ptr(uint32(137))}
})
publishMessagesWithMessageLogic(qName, "ReplyToGroupID", 10, func(msg *amqp.Message) {
msg.Properties = &amqp.MessageProperties{ReplyToGroupID: ptr("ReplyToGroupID")}
})

// ReplyToGroupID
publishMessagesWithMessageLogic(qName, "ReplyToGroupID", 10, func(msg *amqp.Message) {
msg.Properties = &amqp.MessageProperties{ReplyToGroupID: ptr("ReplyToGroupID")}
})
// GroupSequence
publishMessagesWithMessageLogic(qName, "GroupSequence", 10, func(msg *amqp.Message) {
msg.Properties = &amqp.MessageProperties{GroupSequence: ptr(uint32(137))}
})

// CreationTime
// ReplyToGroupID
publishMessagesWithMessageLogic(qName, "ReplyToGroupID", 10, func(msg *amqp.Message) {
msg.Properties = &amqp.MessageProperties{ReplyToGroupID: ptr("ReplyToGroupID")}
})

publishMessagesWithMessageLogic(qName, "CreationTime", 10, func(msg *amqp.Message) {
msg.Properties = &amqp.MessageProperties{CreationTime: ptr(createDateTime())}
})
// CreationTime

// AbsoluteExpiryTime
publishMessagesWithMessageLogic(qName, "CreationTime", 10, func(msg *amqp.Message) {
msg.Properties = &amqp.MessageProperties{CreationTime: ptr(createDateTime())}
})

publishMessagesWithMessageLogic(qName, "AbsoluteExpiryTime", 10, func(msg *amqp.Message) {
msg.Properties = &amqp.MessageProperties{AbsoluteExpiryTime: ptr(createDateTime())}
})
// AbsoluteExpiryTime

// CorrelationID
publishMessagesWithMessageLogic(qName, "AbsoluteExpiryTime", 10, func(msg *amqp.Message) {
msg.Properties = &amqp.MessageProperties{AbsoluteExpiryTime: ptr(createDateTime())}
})

// CorrelationID

publishMessagesWithMessageLogic(qName, "CorrelationID", 10, func(msg *amqp.Message) {
msg.Properties = &amqp.MessageProperties{CorrelationID: "CorrelationID"}
publishMessagesWithMessageLogic(qName, "CorrelationID", 10, func(msg *amqp.Message) {
msg.Properties = &amqp.MessageProperties{CorrelationID: "CorrelationID"}
})
})

AfterEach(func() {
Expect(connection.Management().DeleteQueue(context.Background(), qName)).To(BeNil())
Expect(connection.Close(context.Background())).To(BeNil())
})

var wg sync.WaitGroup
wg.Add(12)
DescribeTable("consumer should filter messages based on properties", func(properties *amqp.MessageProperties, label string) {

consumer, err := connection.NewConsumer(context.Background(), qName, &StreamConsumerOptions{
Expand Down Expand Up @@ -533,7 +545,6 @@ var _ = Describe("Consumer stream test", func() {
Expect(dc.Accept(context.Background())).To(BeNil())
}
Expect(consumer.Close(context.Background())).To(BeNil())
wg.Done()
},
Entry("MessageID", &amqp.MessageProperties{MessageID: "MessageID"}, "MessageID"),
Entry("Subject", &amqp.MessageProperties{Subject: ptr("Subject")}, "Subject"),
Expand All @@ -548,11 +559,6 @@ var _ = Describe("Consumer stream test", func() {
Entry("AbsoluteExpiryTime", &amqp.MessageProperties{AbsoluteExpiryTime: ptr(createDateTime())}, "AbsoluteExpiryTime"),
Entry("CorrelationID", &amqp.MessageProperties{CorrelationID: "CorrelationID"}, "CorrelationID"),
)
go func() {
wg.Wait()
Expect(connection.Management().DeleteQueue(context.Background(), qName)).To(BeNil())
Expect(connection.Close(context.Background())).To(BeNil())
}()
})

It("SQL filter consumer", func() {
Expand Down
13 changes: 7 additions & 6 deletions pkg/rabbitmqamqp/amqp_exchange.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package rabbitmqamqp
import (
"context"
"errors"

"github.com/Azure/go-amqp"
)

Expand All @@ -23,14 +24,14 @@ type AmqpExchange struct {
management *AmqpManagement
arguments map[string]any
isAutoDelete bool
exchangeType ExchangeType
exchangeType TExchangeType
}

func newAmqpExchange(management *AmqpManagement, name string) *AmqpExchange {
return &AmqpExchange{management: management,
name: name,
arguments: make(map[string]any),
exchangeType: ExchangeType{Type: Direct},
exchangeType: Direct,
}
}

Expand All @@ -46,7 +47,7 @@ func (e *AmqpExchange) Declare(ctx context.Context) (*AmqpExchangeInfo, error) {
kv := make(map[string]any)
kv["auto_delete"] = e.isAutoDelete
kv["durable"] = true
kv["type"] = e.exchangeType.String()
kv["type"] = string(e.exchangeType)
if e.arguments != nil {
kv["arguments"] = e.arguments
}
Expand Down Expand Up @@ -78,14 +79,14 @@ func (e *AmqpExchange) Delete(ctx context.Context) error {
return err
}

func (e *AmqpExchange) ExchangeType(exchangeType ExchangeType) {
if len(exchangeType.Type) > 0 {
func (e *AmqpExchange) ExchangeType(exchangeType TExchangeType) {
if len(exchangeType) > 0 {
e.exchangeType = exchangeType
}
}

func (e *AmqpExchange) GetExchangeType() TExchangeType {
return e.exchangeType.Type
return e.exchangeType
}

func (e *AmqpExchange) Name() string {
Expand Down
Loading