diff --git a/.ci/ubuntu/gha-setup.sh b/.ci/ubuntu/gha-setup.sh index 1118c97..8ed0fb6 100755 --- a/.ci/ubuntu/gha-setup.sh +++ b/.ci/ubuntu/gha-setup.sh @@ -7,7 +7,7 @@ set -o xtrace script_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" readonly script_dir echo "[INFO] script_dir: '$script_dir'" -readonly rabbitmq_image=rabbitmq:4.1.0-beta.4-management-alpine +readonly rabbitmq_image=rabbitmq:4.1-management-alpine readonly docker_name_prefix='rabbitmq-amqp-go-client' diff --git a/.github/workflows/build-test.yaml b/.github/workflows/build-test.yaml index 7284091..fad80fc 100644 --- a/.github/workflows/build-test.yaml +++ b/.github/workflows/build-test.yaml @@ -9,7 +9,9 @@ jobs: strategy: fail-fast: true matrix: - go: [ '1.22'] + go: + - stable + - oldstable steps: - uses: actions/checkout@v4 - uses: actions/setup-go@v5 diff --git a/.gitignore b/.gitignore index eecfdeb..6a4c85b 100644 --- a/.gitignore +++ b/.gitignore @@ -27,3 +27,15 @@ go.work.sum coverage.txt .DS_Store .ci/ubuntu/log/rabbitmq.log + +# Visual Studio Code +.vscode/* +!.vscode/settings.json +!.vscode/tasks.json +!.vscode/launch.json +!.vscode/extensions.json +!.vscode/*.code-snippets +!*.code-workspace + +# Built Visual Studio Code Extensions +*.vsix \ No newline at end of file diff --git a/docs/examples/rpc_echo_server/main.go b/docs/examples/rpc_echo_server/main.go new file mode 100644 index 0000000..db93692 --- /dev/null +++ b/docs/examples/rpc_echo_server/main.go @@ -0,0 +1,106 @@ +package main + +import ( + "bufio" + "context" + "fmt" + "os" + "os/signal" + + "github.com/Azure/go-amqp" + "github.com/rabbitmq/rabbitmq-amqp-go-client/pkg/rabbitmqamqp" +) + +type echoRpcServer struct { + conn *rabbitmqamqp.AmqpConnection + server rabbitmqamqp.RpcServer +} + +func (s *echoRpcServer) stop(ctx context.Context) { + s.server.Close(ctx) + s.conn.Close(ctx) +} + +func newEchoRpcServer(conn *rabbitmqamqp.AmqpConnection) *echoRpcServer { + conn.Management().DeclareQueue(context.TODO(), &rabbitmqamqp.QuorumQueueSpecification{ + Name: rpcServerQueueName, + }) + srv, err := conn.NewRpcServer(context.TODO(), rabbitmqamqp.RpcServerOptions{ + RequestQueue: rpcServerQueueName, + Handler: func(ctx context.Context, request *amqp.Message) (*amqp.Message, error) { + return request, nil + }, + }) + if err != nil { + panic(err) + } + return &echoRpcServer{ + conn: conn, + server: srv, + } +} + +const rpcServerQueueName = "rpc-queue" + +func main() { + // Dial rabbit for RPC server connection + srvConn, err := rabbitmqamqp.Dial(context.TODO(), "amqp://localhost:5672", nil) + if err != nil { + panic(err) + } + + srv := newEchoRpcServer(srvConn) + + // Dial rabbit for RPC client connection + clientConn, err := rabbitmqamqp.Dial(context.TODO(), "amqp://localhost:5672", nil) + if err != nil { + panic(err) + } + + rpcClient, err := clientConn.NewRpcClient(context.TODO(), &rabbitmqamqp.RpcClientOptions{ + RequestQueueName: rpcServerQueueName, + }) + if err != nil { + panic(err) + } + + // Set up a channel to listen for OS signals + sigs := make(chan os.Signal, 1) + signal.Notify(sigs, os.Interrupt) // Listen for Ctrl+C + + // Goroutine to handle graceful shutdown + go func() { + <-sigs // Wait for Ctrl+C + fmt.Println("\nReceived Ctrl+C, gracefully shutting down...") + srv.stop(context.TODO()) + _ = clientConn.Close(context.TODO()) + _ = srvConn.Close(context.TODO()) + os.Exit(0) + }() + + reader := bufio.NewReader(os.Stdin) + fmt.Println("Type a message and press Enter to send (Ctrl+C to quit):") + + for { + fmt.Print("Enter message: ") + input, _ := reader.ReadString('\n') + // Remove newline character from input + message := input[:len(input)-1] + + if message == "" { + continue + } + + resp, err := rpcClient.Publish(context.TODO(), amqp.NewMessage([]byte(message))) + if err != nil { + fmt.Printf("Error calling RPC: %v\n", err) + continue + } + m, ok := <-resp + if !ok { + fmt.Println("timed out waiting for response") + continue + } + fmt.Printf("response: %s\n", m.GetData()) + } +} diff --git a/go.mod b/go.mod index ea3c303..335adbe 100644 --- a/go.mod +++ b/go.mod @@ -2,22 +2,25 @@ module github.com/rabbitmq/rabbitmq-amqp-go-client go 1.23.0 +toolchain go1.24.5 + require ( github.com/Azure/go-amqp v1.4.0 github.com/golang-jwt/jwt/v5 v5.2.2 github.com/google/uuid v1.6.0 - github.com/onsi/ginkgo/v2 v2.22.1 - github.com/onsi/gomega v1.36.2 + github.com/onsi/ginkgo/v2 v2.23.4 + github.com/onsi/gomega v1.38.0 ) require ( github.com/go-logr/logr v1.4.2 // indirect github.com/go-task/slim-sprig/v3 v3.0.0 // indirect - github.com/google/go-cmp v0.6.0 // indirect - github.com/google/pprof v0.0.0-20241210010833-40e02aabc2ad // indirect - golang.org/x/net v0.38.0 // indirect - golang.org/x/sys v0.31.0 // indirect - golang.org/x/text v0.23.0 // indirect - golang.org/x/tools v0.28.0 // indirect + github.com/google/go-cmp v0.7.0 // indirect + github.com/google/pprof v0.0.0-20250403155104-27863c87afa6 // indirect + go.uber.org/automaxprocs v1.6.0 // indirect + golang.org/x/net v0.41.0 // indirect + golang.org/x/sys v0.33.0 // indirect + golang.org/x/text v0.26.0 // indirect + golang.org/x/tools v0.33.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index f01925d..66c840d 100644 --- a/go.sum +++ b/go.sum @@ -10,31 +10,40 @@ github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1v github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8= github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= -github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= -github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= -github.com/google/pprof v0.0.0-20241210010833-40e02aabc2ad h1:a6HEuzUHeKH6hwfN/ZoQgRgVIWFJljSWa/zetS2WTvg= -github.com/google/pprof v0.0.0-20241210010833-40e02aabc2ad/go.mod h1:vavhavw2zAxS5dIdcRluK6cSGGPlZynqzFM8NdvU144= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/pprof v0.0.0-20250403155104-27863c87afa6 h1:BHT72Gu3keYf3ZEu2J0b1vyeLSOYI8bm5wbJM/8yDe8= +github.com/google/pprof v0.0.0-20250403155104-27863c87afa6/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/onsi/ginkgo/v2 v2.22.1 h1:QW7tbJAUDyVDVOM5dFa7qaybo+CRfR7bemlQUN6Z8aM= -github.com/onsi/ginkgo/v2 v2.22.1/go.mod h1:S6aTpoRsSq2cZOd+pssHAlKW/Q/jZt6cPrPlnj4a1xM= -github.com/onsi/gomega v1.36.2 h1:koNYke6TVk6ZmnyHrCXba/T/MoLBXFjeC1PtvYgw0A8= -github.com/onsi/gomega v1.36.2/go.mod h1:DdwyADRjrc825LhMEkD76cHR5+pUnjhUN8GlHlRPHzY= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/onsi/ginkgo/v2 v2.23.4 h1:ktYTpKJAVZnDT4VjxSbiBenUjmlL/5QkBEocaWXiQus= +github.com/onsi/ginkgo/v2 v2.23.4/go.mod h1:Bt66ApGPBFzHyR+JO10Zbt0Gsp4uWxu5mIOTusL46e8= +github.com/onsi/gomega v1.38.0 h1:c/WX+w8SLAinvuKKQFh77WEucCnPk4j2OTUr7lt7BeY= +github.com/onsi/gomega v1.38.0/go.mod h1:OcXcwId0b9QsE7Y49u+BTrL4IdKOBOKnD6VQNTJEB6o= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prashantv/gostub v1.1.0 h1:BTyx3RfQjRHnUWaGF9oQos79AlQ5k8WNktv7VGvVH4g= +github.com/prashantv/gostub v1.1.0/go.mod h1:A5zLQHz7ieHGG7is6LLXLz7I8+3LZzsrV0P1IAHhP5U= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= -golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8= -golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= -golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= -golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= -golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= -golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= -golang.org/x/tools v0.28.0 h1:WuB6qZ4RPCQo5aP3WdKZS7i595EdWqWR8vqJTlwTVK8= -golang.org/x/tools v0.28.0/go.mod h1:dcIOrVd3mfQKTgrDVQHqCPMWy6lnhfhtX3hLXYVLfRw= -google.golang.org/protobuf v1.36.1 h1:yBPeRvTftaleIgM3PZ/WBIZ7XM/eEYAaEyCwvyjq/gk= -google.golang.org/protobuf v1.36.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +go.uber.org/automaxprocs v1.6.0 h1:O3y2/QNTOdbF+e/dpXNNW7Rx2hZ4sTIPyybbxyNqTUs= +go.uber.org/automaxprocs v1.6.0/go.mod h1:ifeIMSnPZuznNm6jmdzmU3/bfk01Fe2fotchwEFJ8r8= +golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw= +golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA= +golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= +golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M= +golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA= +golang.org/x/tools v0.33.0 h1:4qz2S3zmRxbGIhDIAgjxvFutSvH5EfnsYrRBj0UI0bc= +golang.org/x/tools v0.33.0/go.mod h1:CIJMaWEY88juyUfo7UbgPqbC8rU2OqfAV1h2Qp0oMYI= +google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY= +google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/pkg/rabbitmqamqp/address_test.go b/pkg/rabbitmqamqp/address_test.go index f629193..fbf2547 100644 --- a/pkg/rabbitmqamqp/address_test.go +++ b/pkg/rabbitmqamqp/address_test.go @@ -15,7 +15,7 @@ var _ = Describe("address builder test ", func() { Expect(err.Error()).To(Equal(expectedErr)) }, Entry("when both exchange and queue are set", - stringPtr("my_exchange"), nil, stringPtr("my_queue"), + ptr("my_exchange"), nil, ptr("my_queue"), "exchange and queue cannot be set together"), Entry("when neither exchange nor queue is set", nil, nil, nil, @@ -32,13 +32,13 @@ var _ = Describe("address builder test ", func() { Expect(address).To(Equal(expected)) }, Entry("with exchange and key", - stringPtr("my_exchange"), stringPtr("my_key"), + ptr("my_exchange"), ptr("my_key"), "/exchanges/my_exchange/my_key"), Entry("with exchange only", - stringPtr("my_exchange"), nil, + ptr("my_exchange"), nil, "/exchanges/my_exchange"), Entry("with special characters", - stringPtr("my_ exchange/()"), stringPtr("my_key "), + ptr("my_ exchange/()"), ptr("my_key "), "/exchanges/my_%20exchange%2F%28%29/my_key%20"), ) }) diff --git a/pkg/rabbitmqamqp/amqp_connection.go b/pkg/rabbitmqamqp/amqp_connection.go index 6583579..84bdadc 100644 --- a/pkg/rabbitmqamqp/amqp_connection.go +++ b/pkg/rabbitmqamqp/amqp_connection.go @@ -3,15 +3,19 @@ package rabbitmqamqp import ( "context" "crypto/tls" + "errors" "fmt" - "github.com/Azure/go-amqp" - "github.com/google/uuid" "math/rand" "sync" "sync/atomic" "time" + + "github.com/Azure/go-amqp" + "github.com/google/uuid" ) +var ErrConnectionClosed = errors.New("connection is closed") + type AmqpAddress struct { // the address of the AMQP server // it is in the form of amqp://: @@ -120,6 +124,8 @@ type AmqpConnection struct { session *amqp.Session refMap *sync.Map entitiesTracker *entitiesTracker + mutex sync.RWMutex + closed bool } func (a *AmqpConnection) Properties() map[string]any { @@ -168,6 +174,146 @@ func (a *AmqpConnection) NewConsumer(ctx context.Context, queueName string, opti return newConsumer(ctx, a, destinationAdd, options) } +// NewRpcServer creates a new RPC server that processes requests from the +// specified queue. The requestQueue in options is mandatory, while other +// fields are optional and will use defaults if not provided. +func (a *AmqpConnection) NewRpcServer(ctx context.Context, options RpcServerOptions) (RpcServer, error) { + if err := options.validate(); err != nil { + return nil, fmt.Errorf("rpc server options validation: %w", err) + } + + // Create consumer for receiving requests + // consumer, err := a.NewConsumer(ctx, options.RequestQueue, nil) + consumer, err := a.NewConsumer(ctx, options.RequestQueue, &ConsumerOptions{InitialCredits: -1}) + if err != nil { + return nil, fmt.Errorf("failed to create consumer: %w", err) + } + consumer.issueCredits(1) + + // Create publisher for sending replies + publisher, err := a.NewPublisher(ctx, nil, nil) + if err != nil { + consumer.Close(ctx) // cleanup consumer on failure + return nil, fmt.Errorf("failed to create publisher: %w", err) + } + + // Set defaults for optional fields + handler := options.Handler + if handler == nil { + handler = noOpHandler + } + + correlationIdExtractor := options.CorrelationIdExtractor + if correlationIdExtractor == nil { + correlationIdExtractor = defaultCorrelationIdExtractor + } + + replyPostProcessor := options.ReplyPostProcessor + if replyPostProcessor == nil { + replyPostProcessor = defaultReplyPostProcessor + } + + server := &amqpRpcServer{ + requestHandler: handler, + requestQueue: options.RequestQueue, + publisher: publisher, + consumer: consumer, + correlationIdExtractor: correlationIdExtractor, + replyPostProcessor: replyPostProcessor, + } + go server.handle() + + return server, nil +} + +// NewRpcClient creates a new RPC client that sends requests to the specified queue +// and receives replies on a dynamically created reply queue. +func (a *AmqpConnection) NewRpcClient(ctx context.Context, options *RpcClientOptions) (RpcClient, error) { + if options == nil { + return nil, fmt.Errorf("options cannot be nil") + } + if options.RequestQueueName == "" { + return nil, fmt.Errorf("requestQueueName is mandatory") + } + + // Create publisher for sending requests + requestQueue := &QueueAddress{ + Queue: options.RequestQueueName, + } + publisher, err := a.NewPublisher(ctx, requestQueue, nil) + if err != nil { + return nil, fmt.Errorf("failed to create publisher: %w", err) + } + + replyQueueName := options.ReplyToQueueName + if len(replyQueueName) == 0 { + replyQueueName = generateNameWithDefaultPrefix() + } + + // Declare reply queue as exclusive, auto-delete classic queue + q, err := a.management.DeclareQueue(ctx, &ClassicQueueSpecification{ + Name: replyQueueName, + IsExclusive: true, + IsAutoDelete: true, + }) + if err != nil { + return nil, fmt.Errorf("failed to declare reply queue: %w", err) + } + + // Set defaults for optional fields + correlationIdSupplier := options.CorrelationIdSupplier + if correlationIdSupplier == nil { + correlationIdSupplier = newRandomUuidCorrelationIdSupplier() + } + + requestPostProcessor := options.RequestPostProcessor + if requestPostProcessor == nil { + requestPostProcessor = func(request *amqp.Message, correlationID any) *amqp.Message { + if request.Properties == nil { + request.Properties = &amqp.MessageProperties{} + } + request.Properties.MessageID = correlationID + return request + } + } + + requestTimeout := options.RequestTimeout + if requestTimeout == 0 { + requestTimeout = DefaultRpcRequestTimeout + } + + correlationIdExtractor := options.CorrelationIdExtractor + if correlationIdExtractor == nil { + correlationIdExtractor = defaultReplyCorrelationIdExtractor + } + + client := &amqpRpcClient{ + requestQueue: requestQueue, + replyToQueue: &QueueAddress{Queue: replyQueueName}, + publisher: publisher, + requestPostProcessor: requestPostProcessor, + correlationIdSupplier: correlationIdSupplier, + correlationIdExtractor: correlationIdExtractor, + requestTimeout: requestTimeout, + pendingRequests: make(map[any]*outstandingRequest), + done: make(chan struct{}), + } + + // Create consumer for receiving replies + consumer, err := a.NewConsumer(ctx, q.Name(), nil) + if err != nil { + _ = publisher.Close(ctx) // cleanup publisher on failure + return nil, fmt.Errorf("failed to create consumer: %w", err) + } + + client.consumer = consumer + + go client.messageReceivedHandler() + go client.requestTimeoutTask() + + return client, nil +} + // Dial connect to the AMQP 1.0 server using the provided connectionSettings // Returns a pointer to the new AmqpConnection if successful else an error. func Dial(ctx context.Context, address string, connOptions *AmqpConnOptions) (*AmqpConnection, error) { @@ -238,6 +384,12 @@ func validateOptions(connOptions *AmqpConnOptions) (*AmqpConnOptions, error) { // using the provided connectionSettings and the AMQPLite library. // Setups the connection and the management interface. func (a *AmqpConnection) open(ctx context.Context, address string, connOptions *AmqpConnOptions) error { + a.mutex.Lock() + defer a.mutex.Unlock() + + if a.closed { + return ErrConnectionClosed + } // random pick and extract one address to use for connection var azureConnection *amqp.Conn @@ -315,7 +467,6 @@ func (a *AmqpConnection) open(ctx context.Context, address string, connOptions * return nil } func (a *AmqpConnection) maybeReconnect() { - if !a.amqpConnOptions.RecoveryConfiguration.ActiveRecovery { Info("Recovery is disabled, closing connection", "ID", a.Id()) return @@ -326,7 +477,6 @@ func (a *AmqpConnection) maybeReconnect() { maxDelay := 1 * time.Minute for attempt := 1; attempt <= a.amqpConnOptions.RecoveryConfiguration.MaxReconnectAttempts; attempt++ { - ///wait for before reconnecting // add some random milliseconds to the wait time to avoid thundering herd // the random time is between 0 and 500 milliseconds @@ -350,6 +500,12 @@ func (a *AmqpConnection) maybeReconnect() { a.lifeCycle.SetState(&StateOpen{}) return } + + if errors.Is(err, ErrConnectionClosed) { + Info("Connection was closed during reconnect, aborting.", "ID", a.Id()) + return + } + baseDelay *= 2 Error("Reconnection attempt failed", "attempt", attempt, "error", err, "ID", a.Id()) } @@ -407,6 +563,13 @@ Close closes the connection to the AMQP 1.0 server and the management interface. All the publishers and consumers are closed as well. */ func (a *AmqpConnection) Close(ctx context.Context) error { + a.mutex.Lock() + if a.closed { + a.mutex.Unlock() + return nil + } + a.closed = true + defer a.mutex.Unlock() // the status closed (lifeCycle.SetState(&StateClosed{error: nil})) is not set here // it is set in the connection.Done() channel // the channel is called anyway diff --git a/pkg/rabbitmqamqp/amqp_consumer.go b/pkg/rabbitmqamqp/amqp_consumer.go index 5207f11..8bbdd76 100644 --- a/pkg/rabbitmqamqp/amqp_consumer.go +++ b/pkg/rabbitmqamqp/amqp_consumer.go @@ -3,9 +3,10 @@ package rabbitmqamqp import ( "context" "fmt" + "sync/atomic" + "github.com/Azure/go-amqp" "github.com/google/uuid" - "sync/atomic" ) type DeliveryContext struct { @@ -65,6 +66,14 @@ func (dc *DeliveryContext) RequeueWithAnnotations(ctx context.Context, annotatio }) } +type consumerState byte + +const ( + consumerStateRunning consumerState = iota + consumerStatePausing + consumerStatePaused +) + type Consumer struct { receiver atomic.Pointer[amqp.Receiver] connection *AmqpConnection @@ -79,6 +88,8 @@ type Consumer struct { For the AMQP queues it is just ignored. */ currentOffset int64 + + state consumerState } func (c *Consumer) Id() string { @@ -148,3 +159,40 @@ func (c *Consumer) Close(ctx context.Context) error { c.connection.entitiesTracker.removeConsumer(c) return c.receiver.Load().Close(ctx) } + +// pause drains the credits of the receiver and stops issuing new credits. +func (c *Consumer) pause(ctx context.Context) error { + if c.state == consumerStatePaused || c.state == consumerStatePausing { + return nil + } + c.state = consumerStatePausing + err := c.receiver.Load().DrainCredit(ctx, nil) + if err != nil { + c.state = consumerStateRunning + return fmt.Errorf("error draining credits: %w", err) + } + c.state = consumerStatePaused + return nil +} + +// unpause requests new credits using the initial credits value of the options. +func (c *Consumer) unpause(credits uint32) error { + if c.state == consumerStateRunning { + return nil + } + err := c.receiver.Load().IssueCredit(credits) + if err != nil { + return fmt.Errorf("error issuing credits: %w", err) + } + c.state = consumerStateRunning + return nil +} + +func (c *Consumer) isPausedOrPausing() bool { + return c.state != consumerStateRunning +} + +// issueCredits issues more credits on the receiver. +func (c *Consumer) issueCredits(credits uint32) error { + return c.receiver.Load().IssueCredit(credits) +} diff --git a/pkg/rabbitmqamqp/amqp_consumer_stream_test.go b/pkg/rabbitmqamqp/amqp_consumer_stream_test.go index ecd3b02..3372996 100644 --- a/pkg/rabbitmqamqp/amqp_consumer_stream_test.go +++ b/pkg/rabbitmqamqp/amqp_consumer_stream_test.go @@ -3,12 +3,13 @@ package rabbitmqamqp import ( "context" "fmt" + "sync" + "time" + "github.com/Azure/go-amqp" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" testhelper "github.com/rabbitmq/rabbitmq-amqp-go-client/pkg/test-helper" - "sync" - "time" ) var _ = Describe("Consumer stream test", func() { @@ -392,6 +393,7 @@ var _ = Describe("Consumer stream test", func() { /* Test the consumer should filter messages based on properties */ + // TODO: defer cleanup to delete the stream queue qName := generateName("consumer should filter messages based on properties") qName += time.Now().String() connection, err := Dial(context.Background(), "amqp://", nil) @@ -407,49 +409,49 @@ var _ = Describe("Consumer stream test", func() { }) publishMessagesWithMessageLogic(qName, "Subject", 10, func(msg *amqp.Message) { - msg.Properties = &amqp.MessageProperties{Subject: stringPtr("Subject")} + msg.Properties = &amqp.MessageProperties{Subject: ptr("Subject")} }) publishMessagesWithMessageLogic(qName, "ReplyTo", 10, func(msg *amqp.Message) { - msg.Properties = &amqp.MessageProperties{ReplyTo: stringPtr("ReplyTo")} + msg.Properties = &amqp.MessageProperties{ReplyTo: ptr("ReplyTo")} }) publishMessagesWithMessageLogic(qName, "ContentType", 10, func(msg *amqp.Message) { - msg.Properties = &amqp.MessageProperties{ContentType: stringPtr("ContentType")} + msg.Properties = &amqp.MessageProperties{ContentType: ptr("ContentType")} }) publishMessagesWithMessageLogic(qName, "ContentEncoding", 10, func(msg *amqp.Message) { - msg.Properties = &amqp.MessageProperties{ContentEncoding: stringPtr("ContentEncoding")} + msg.Properties = &amqp.MessageProperties{ContentEncoding: ptr("ContentEncoding")} }) publishMessagesWithMessageLogic(qName, "GroupID", 10, func(msg *amqp.Message) { - msg.Properties = &amqp.MessageProperties{GroupID: stringPtr("GroupID")} + msg.Properties = &amqp.MessageProperties{GroupID: ptr("GroupID")} }) publishMessagesWithMessageLogic(qName, "ReplyToGroupID", 10, func(msg *amqp.Message) { - msg.Properties = &amqp.MessageProperties{ReplyToGroupID: stringPtr("ReplyToGroupID")} + msg.Properties = &amqp.MessageProperties{ReplyToGroupID: ptr("ReplyToGroupID")} }) // GroupSequence publishMessagesWithMessageLogic(qName, "GroupSequence", 10, func(msg *amqp.Message) { - msg.Properties = &amqp.MessageProperties{GroupSequence: uint32Ptr(137)} + msg.Properties = &amqp.MessageProperties{GroupSequence: ptr(uint32(137))} }) // ReplyToGroupID publishMessagesWithMessageLogic(qName, "ReplyToGroupID", 10, func(msg *amqp.Message) { - msg.Properties = &amqp.MessageProperties{ReplyToGroupID: stringPtr("ReplyToGroupID")} + msg.Properties = &amqp.MessageProperties{ReplyToGroupID: ptr("ReplyToGroupID")} }) // CreationTime publishMessagesWithMessageLogic(qName, "CreationTime", 10, func(msg *amqp.Message) { - msg.Properties = &amqp.MessageProperties{CreationTime: timePtr(createDateTime())} + msg.Properties = &amqp.MessageProperties{CreationTime: ptr(createDateTime())} }) // AbsoluteExpiryTime publishMessagesWithMessageLogic(qName, "AbsoluteExpiryTime", 10, func(msg *amqp.Message) { - msg.Properties = &amqp.MessageProperties{AbsoluteExpiryTime: timePtr(createDateTime())} + msg.Properties = &amqp.MessageProperties{AbsoluteExpiryTime: ptr(createDateTime())} }) // CorrelationID @@ -534,16 +536,16 @@ var _ = Describe("Consumer stream test", func() { wg.Done() }, Entry("MessageID", &amqp.MessageProperties{MessageID: "MessageID"}, "MessageID"), - Entry("Subject", &amqp.MessageProperties{Subject: stringPtr("Subject")}, "Subject"), - Entry("ReplyTo", &amqp.MessageProperties{ReplyTo: stringPtr("ReplyTo")}, "ReplyTo"), - Entry("ContentType", &amqp.MessageProperties{ContentType: stringPtr("ContentType")}, "ContentType"), - Entry("ContentEncoding", &amqp.MessageProperties{ContentEncoding: stringPtr("ContentEncoding")}, "ContentEncoding"), - Entry("GroupID", &amqp.MessageProperties{GroupID: stringPtr("GroupID")}, "GroupID"), - Entry("ReplyToGroupID", &amqp.MessageProperties{ReplyToGroupID: stringPtr("ReplyToGroupID")}, "ReplyToGroupID"), - Entry("GroupSequence", &amqp.MessageProperties{GroupSequence: uint32Ptr(137)}, "GroupSequence"), - Entry("ReplyToGroupID", &amqp.MessageProperties{ReplyToGroupID: stringPtr("ReplyToGroupID")}, "ReplyToGroupID"), - Entry("CreationTime", &amqp.MessageProperties{CreationTime: timePtr(createDateTime())}, "CreationTime"), - Entry("AbsoluteExpiryTime", &amqp.MessageProperties{AbsoluteExpiryTime: timePtr(createDateTime())}, "AbsoluteExpiryTime"), + Entry("Subject", &amqp.MessageProperties{Subject: ptr("Subject")}, "Subject"), + Entry("ReplyTo", &amqp.MessageProperties{ReplyTo: ptr("ReplyTo")}, "ReplyTo"), + Entry("ContentType", &amqp.MessageProperties{ContentType: ptr("ContentType")}, "ContentType"), + Entry("ContentEncoding", &amqp.MessageProperties{ContentEncoding: ptr("ContentEncoding")}, "ContentEncoding"), + Entry("GroupID", &amqp.MessageProperties{GroupID: ptr("GroupID")}, "GroupID"), + Entry("ReplyToGroupID", &amqp.MessageProperties{ReplyToGroupID: ptr("ReplyToGroupID")}, "ReplyToGroupID"), + Entry("GroupSequence", &amqp.MessageProperties{GroupSequence: ptr(uint32(137))}, "GroupSequence"), + Entry("ReplyToGroupID", &amqp.MessageProperties{ReplyToGroupID: ptr("ReplyToGroupID")}, "ReplyToGroupID"), + Entry("CreationTime", &amqp.MessageProperties{CreationTime: ptr(createDateTime())}, "CreationTime"), + Entry("AbsoluteExpiryTime", &amqp.MessageProperties{AbsoluteExpiryTime: ptr(createDateTime())}, "AbsoluteExpiryTime"), Entry("CorrelationID", &amqp.MessageProperties{CorrelationID: "CorrelationID"}, "CorrelationID"), ) go func() { diff --git a/pkg/rabbitmqamqp/amqp_consumer_test.go b/pkg/rabbitmqamqp/amqp_consumer_test.go index 6a5bd73..2a2b686 100644 --- a/pkg/rabbitmqamqp/amqp_consumer_test.go +++ b/pkg/rabbitmqamqp/amqp_consumer_test.go @@ -2,6 +2,7 @@ package rabbitmqamqp import ( "context" + "github.com/Azure/go-amqp" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" @@ -161,5 +162,49 @@ var _ = Describe("NewConsumer tests", func() { Expect(connection.Management().DeleteQueue(context.Background(), qName)).To(BeNil()) Expect(connection.Close(context.Background())).To(BeNil()) }) +}) + +var _ = Describe("Consumer pause and unpause", func() { + It("pauses and unpauses the consumer", func(ctx SpecContext) { + // setup + qName := CurrentSpecReport().LeafNodeText + c, err := declareQueueAndConnection(qName) + Expect(err).ToNot(HaveOccurred()) + DeferCleanup(func(ctx SpecContext) { + _ = c.Close(ctx) + }) + + publishMessages(qName, 1) + consumer, err := c.NewConsumer(ctx, qName, &ConsumerOptions{InitialCredits: -1}) + Expect(err).ToNot(HaveOccurred()) + Expect(consumer.receiver.Load().IssueCredit(1)).To(Succeed()) + + By("receiving a message when unpaused") + dc, err := consumer.Receive(ctx) + Expect(err).ToNot(HaveOccurred()) + Expect(dc.Accept(ctx)).To(Succeed()) + + By("not receiving any new messages after pausing") + Expect(consumer.pause(ctx)).To(Succeed()) + Eventually(consumer.isPausedOrPausing).Should(BeTrue(), "expected consumer to be paused or pausing") + publishMessages(qName, 10) + // have to assert again because pause may enocunter an error and not complete the pause operation + Eventually(consumer.isPausedOrPausing).Should(BeTrue(), "expected consumer to be paused") + + rCtx, rCancel := context.WithTimeout(ctx, 200*time.Millisecond) + DeferCleanup(rCancel) + _, err = consumer.Receive(rCtx) + Expect(err).To(MatchError(context.DeadlineExceeded)) + + By("receiving a new message after unpausing") + Expect(consumer.unpause(10)).To(Succeed()) + Eventually(consumer.isPausedOrPausing).Should(BeFalse(), "expected consumer to be unpaused") + + for i := 0; i < 10; i++ { + dc, err = consumer.Receive(ctx) + Expect(err).ToNot(HaveOccurred()) + Expect(dc.Accept(ctx)).To(Succeed()) + } + }, SpecTimeout(time.Second*10)) }) diff --git a/pkg/rabbitmqamqp/common.go b/pkg/rabbitmqamqp/common.go index 2f29e81..cd56803 100644 --- a/pkg/rabbitmqamqp/common.go +++ b/pkg/rabbitmqamqp/common.go @@ -4,8 +4,11 @@ import ( "crypto/md5" "encoding/base64" "fmt" - "github.com/google/uuid" "strings" + "time" + + "github.com/Azure/go-amqp" + "github.com/google/uuid" ) // public consts @@ -61,3 +64,28 @@ func isStringNilOrEmpty(str *string) bool { return str == nil || len(*str) == 0 } + +func callAndMaybeRetry(fn func() error, delays []time.Duration) error { + var err error + for i, delay := range delays { + err = fn() + if err == nil { + return nil + } + Error("Retrying operation", "attempt", i+1, "error", err) + if i < len(delays)-1 { // Don't sleep after the last attempt + time.Sleep(delay) + } + } + return fmt.Errorf("failed after %d attempts: %w", len(delays), err) +} + +// setToProperty sets the To property of the message m to the value of replyTo. +// If the message has no properties, it creates a new properties object. +// This function modifies the message in place. +func setToProperty(m *amqp.Message, replyTo *string) { + if m.Properties == nil { + m.Properties = &amqp.MessageProperties{} + } + m.Properties.To = replyTo +} diff --git a/pkg/rabbitmqamqp/example_rpc_custom_test.go b/pkg/rabbitmqamqp/example_rpc_custom_test.go new file mode 100644 index 0000000..2e4822c --- /dev/null +++ b/pkg/rabbitmqamqp/example_rpc_custom_test.go @@ -0,0 +1,108 @@ +package rabbitmqamqp_test + +import ( + "context" + "fmt" + + "github.com/Azure/go-amqp" + "github.com/google/uuid" + "github.com/rabbitmq/rabbitmq-amqp-go-client/pkg/rabbitmqamqp" +) + +const ( + rpcServerQueueNameCustom = "rpc-queue-custom" + correlationIDHeader = "x-correlation-id" +) + +type customCorrelationIDSupplier struct{} + +func (s *customCorrelationIDSupplier) Get() any { + return uuid.New().String() +} + +func Example_customCorrelationId() { + // Dial rabbit for RPC server connection + srvConn, err := rabbitmqamqp.Dial(context.TODO(), "amqp://localhost:5672", nil) + if err != nil { + panic(err) + } + defer srvConn.Close(context.Background()) + + _, err = srvConn.Management().DeclareQueue(context.TODO(), &rabbitmqamqp.QuorumQueueSpecification{ + Name: rpcServerQueueNameCustom, + }) + if err != nil { + panic(err) + } + + server, err := srvConn.NewRpcServer(context.TODO(), rabbitmqamqp.RpcServerOptions{ + RequestQueue: rpcServerQueueNameCustom, + Handler: func(ctx context.Context, request *amqp.Message) (*amqp.Message, error) { + fmt.Printf("Received: %s\n", request.GetData()) + return request, nil + }, + CorrelationIdExtractor: func(message *amqp.Message) any { + if message.ApplicationProperties == nil { + panic("application properties are missing") + } + return message.ApplicationProperties[correlationIDHeader] + }, + ReplyPostProcessor: func(reply *amqp.Message, correlationID any) *amqp.Message { + if reply.ApplicationProperties == nil { + reply.ApplicationProperties = make(map[string]interface{}) + } + reply.ApplicationProperties[correlationIDHeader] = correlationID + return reply + }, + }) + if err != nil { + panic(err) + } + defer server.Close(context.Background()) + + // Dial rabbit for RPC client connection + clientConn, err := rabbitmqamqp.Dial(context.TODO(), "amqp://localhost:5672", nil) + if err != nil { + panic(err) + } + defer clientConn.Close(context.Background()) + + rpcClient, err := clientConn.NewRpcClient(context.TODO(), &rabbitmqamqp.RpcClientOptions{ + RequestQueueName: rpcServerQueueNameCustom, + CorrelationIdSupplier: &customCorrelationIDSupplier{}, + CorrelationIdExtractor: func(message *amqp.Message) any { + if message.ApplicationProperties == nil { + panic("application properties are missing") + } + return message.ApplicationProperties[correlationIDHeader] + }, + RequestPostProcessor: func(request *amqp.Message, correlationID any) *amqp.Message { + if request.ApplicationProperties == nil { + request.ApplicationProperties = make(map[string]interface{}) + } + request.ApplicationProperties[correlationIDHeader] = correlationID + return request + }, + }) + if err != nil { + panic(err) + } + defer rpcClient.Close(context.Background()) + + message := "hello world" + resp, err := rpcClient.Publish(context.TODO(), amqp.NewMessage([]byte(message))) + if err != nil { + fmt.Printf("Error calling RPC: %v\n", err) + return + } + + m, ok := <-resp + if !ok { + fmt.Println("timed out waiting for response") + return + } + fmt.Printf("Response: %s\n", m.GetData()) + // Output: + // Received: hello world + // Response: hello world +} diff --git a/pkg/rabbitmqamqp/example_rpc_test.go b/pkg/rabbitmqamqp/example_rpc_test.go new file mode 100644 index 0000000..96193aa --- /dev/null +++ b/pkg/rabbitmqamqp/example_rpc_test.go @@ -0,0 +1,107 @@ +package rabbitmqamqp_test + +import ( + "bufio" + "context" + "fmt" + "os" + "os/signal" + + "github.com/Azure/go-amqp" + "github.com/rabbitmq/rabbitmq-amqp-go-client/pkg/rabbitmqamqp" +) + +type echoRpcServer struct { + conn *rabbitmqamqp.AmqpConnection + server rabbitmqamqp.RpcServer +} + +func (s *echoRpcServer) stop(ctx context.Context) { + s.server.Close(ctx) + s.conn.Close(ctx) +} + +func newEchoRpcServer(conn *rabbitmqamqp.AmqpConnection) *echoRpcServer { + conn.Management().DeclareQueue(context.TODO(), &rabbitmqamqp.QuorumQueueSpecification{ + Name: rpcServerQueueName, + }) + srv, err := conn.NewRpcServer(context.TODO(), rabbitmqamqp.RpcServerOptions{ + RequestQueue: rpcServerQueueName, + Handler: func(ctx context.Context, request *amqp.Message) (*amqp.Message, error) { + fmt.Printf("echo: %s\n", request.GetData()) + return request, nil + }, + }) + if err != nil { + panic(err) + } + return &echoRpcServer{ + conn: conn, + server: srv, + } +} + +const rpcServerQueueName = "rpc-queue" + +func Example() { + // Dial rabbit for RPC server connection + srvConn, err := rabbitmqamqp.Dial(context.TODO(), "amqp://localhost:5672", nil) + if err != nil { + panic(err) + } + + srv := newEchoRpcServer(srvConn) + + // Dial rabbit for RPC client connection + clientConn, err := rabbitmqamqp.Dial(context.TODO(), "amqp://localhost:5672", nil) + if err != nil { + panic(err) + } + + rpcClient, err := clientConn.NewRpcClient(context.TODO(), &rabbitmqamqp.RpcClientOptions{ + RequestQueueName: rpcServerQueueName, + }) + if err != nil { + panic(err) + } + + // Set up a channel to listen for OS signals + sigs := make(chan os.Signal, 1) + signal.Notify(sigs, os.Interrupt) // Listen for Ctrl+C + + // Goroutine to handle graceful shutdown + go func() { + <-sigs // Wait for Ctrl+C + fmt.Println("\nReceived Ctrl+C, gracefully shutting down...") + srv.stop(context.TODO()) + _ = clientConn.Close(context.TODO()) + _ = srvConn.Close(context.TODO()) + os.Exit(0) + }() + + reader := bufio.NewReader(os.Stdin) + fmt.Println("Type a message and press Enter to send (Ctrl+C to quit):") + + for { + fmt.Print("Enter message: ") + input, _ := reader.ReadString('\n') + // Remove newline character from input + message := input[:len(input)-1] + + if message == "" { + continue + } + + resp, err := rpcClient.Publish(context.TODO(), amqp.NewMessage([]byte(message))) + if err != nil { + fmt.Printf("Error calling RPC: %v\n", err) + continue + } + m, ok := <-resp + if !ok { + fmt.Println("timed out waiting for response") + continue + } + fmt.Printf("response: %s\n", m.GetData()) + } +} diff --git a/pkg/rabbitmqamqp/log.go b/pkg/rabbitmqamqp/log.go index a2ecde5..7e71456 100644 --- a/pkg/rabbitmqamqp/log.go +++ b/pkg/rabbitmqamqp/log.go @@ -2,6 +2,10 @@ package rabbitmqamqp import "log/slog" +func SetSlogHandler(handler slog.Handler) { + slog.SetDefault(slog.New(handler)) +} + func Info(msg string, args ...any) { slog.Info(msg, args...) } diff --git a/pkg/rabbitmqamqp/pkg_suite_test.go b/pkg/rabbitmqamqp/pkg_suite_test.go index e074a1a..c9f4a9b 100644 --- a/pkg/rabbitmqamqp/pkg_suite_test.go +++ b/pkg/rabbitmqamqp/pkg_suite_test.go @@ -1,13 +1,16 @@ package rabbitmqamqp_test import ( + "log/slog" "testing" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "github.com/rabbitmq/rabbitmq-amqp-go-client/pkg/rabbitmqamqp" ) func TestPkg(t *testing.T) { + rabbitmqamqp.SetSlogHandler(slog.NewTextHandler(GinkgoWriter, &slog.HandlerOptions{Level: slog.LevelDebug})) RegisterFailHandler(Fail) RunSpecs(t, "Pkg Suite") } diff --git a/pkg/rabbitmqamqp/rpc_client.go b/pkg/rabbitmqamqp/rpc_client.go new file mode 100644 index 0000000..c977d73 --- /dev/null +++ b/pkg/rabbitmqamqp/rpc_client.go @@ -0,0 +1,300 @@ +package rabbitmqamqp + +import ( + "context" + "errors" + "fmt" + "maps" + "sync" + "time" + + "github.com/Azure/go-amqp" + "github.com/google/uuid" +) + +// RpcClient is an interface for making RPC (Remote Procedure Call) requests over AMQP. +// Implementations of this interface should handle the sending of requests and +// the receiving of corresponding replies, managing correlation IDs and timeouts. +// +// The default implementation provides the following behaviour: +// - Requests are published to a specified request queue. This queue must be pre-declared. +// - Replies are consumed from a dedicated reply-to queue. This queue is dynamically created +// by the client. +// - Correlation IDs are used to match requests with replies. The default implementation +// uses a random UUID as prefix and an auto-incrementing counter as suffix. +// The UUIDs are set as MessageID in the request message. +// - A request timeout mechanism is in place to handle unacknowledged replies. +// - Messages are pre-processed before publishing. The default implementation +// assigns the correlation ID to the MessageID property of the request message. +// - Replies are simply sent over the "callback" channel. +// +// Implementers should ensure that: +// - `Close` properly shuts down underlying resources like publishers and consumers. +// - `Message` provides a basic AMQP message structure for RPC requests. +// - `Publish` sends the request message and returns a channel that will receive +// the reply message, or be closed if a timeout occurs or the client is closed. +type RpcClient interface { + Close(context.Context) error + Message(body []byte) *amqp.Message + Publish(context.Context, *amqp.Message) (<-chan *amqp.Message, error) +} + +// CorrelationIdSupplier is an interface for providing correlation IDs for RPC requests. +// Implementations should generate unique identifiers for each request. +// The returned value from `Get()` should be an AMQP type, or a type that can be +// encoded into an AMQP message property (e.g., string, int, []byte, etc.). +type CorrelationIdSupplier interface { + Get() any +} + +type randomUuidCorrelationIdSupplier struct { + mu sync.Mutex + prefix string + count int +} + +func (c *randomUuidCorrelationIdSupplier) Get() any { + c.mu.Lock() + defer c.mu.Unlock() + s := fmt.Sprintf("%s-%d", c.prefix, c.count) + c.count += 1 + return s +} + +func newRandomUuidCorrelationIdSupplier() CorrelationIdSupplier { + u, err := uuid.NewRandom() + if err != nil { + panic(err) + } + return &randomUuidCorrelationIdSupplier{ + prefix: u.String(), + count: 0, + } +} + +var defaultReplyCorrelationIdExtractor CorrelationIdExtractor = func(message *amqp.Message) any { + if message.Properties == nil || message.Properties.CorrelationID == nil { + return nil + } + return message.Properties.CorrelationID +} + +// RequestPostProcessor is a function that modifies an AMQP message before it is sent +// as an RPC request. It receives the message about to be sent and the correlation ID +// generated for the request. Implementations must assign the correlation ID to a +// message property (e.g., `MessageID` or `CorrelationID`) and set the `ReplyTo` +// address for the reply queue. The function must return the modified message. +// +// The default `RequestPostProcessor` implementation (used when `RequestPostProcessor` +// is not explicitly set in `RpcClientOptions`) performs the following: +// - Assigns the `correlationID` to the `MessageID` property of the `amqp.Message`. +// - Sets the `ReplyTo` message property to a client-generated exclusive auto-delete queue. +type RequestPostProcessor func(request *amqp.Message, correlationID any) *amqp.Message + +var DefaultRpcRequestTimeout = 30 * time.Second + +// RpcClientOptions is a struct that contains the options for the RPC client. +// It is used to configure the RPC client. +type RpcClientOptions struct { + // The name of the queue to send requests to. This queue must exist. + // + // Mandatory. + RequestQueueName string + // The name of the queue to receive replies from. + // + // Optional. If not set, a dedicated reply-to queue will be created for each request. + ReplyToQueueName string + // Generator of correlation IDs for requests. Each correlationID generated must be unique. + // + // Optional. If not set, a random UUID will be used as prefix and an auto-incrementing counter as suffix. + CorrelationIdSupplier CorrelationIdSupplier + // Function to extract correlation IDs from replies. + // + // Optional. If not set, the `CorrelationID` message property will be used. + CorrelationIdExtractor CorrelationIdExtractor + // Function to modify requests before they are sent. + // + // Optional. If not set, the default `RequestPostProcessor` assigns the correlation ID to the `MessageID` property. + RequestPostProcessor RequestPostProcessor + // The timeout for requests. + // + // Optional. If not set, a default timeout of 30 seconds will be used. + RequestTimeout time.Duration +} + +type outstandingRequest struct { + sentAt time.Time + ch chan *amqp.Message + // TODO: chat to Gabriele about this: shall we communicate via an error that the request timed out? + // or shall we just close the channel and document that if channel is closed and received is nil, it means the request timed out? + // err error +} + +type amqpRpcClient struct { + requestQueue ITargetAddress + replyToQueue ITargetAddress + publisher *Publisher + consumer *Consumer + requestPostProcessor RequestPostProcessor + correlationIdSupplier CorrelationIdSupplier + correlationIdExtractor CorrelationIdExtractor + requestTimeout time.Duration + mu sync.Mutex + pendingRequests map[any]*outstandingRequest + closed bool + done chan struct{} + closer sync.Once +} + +// Close shuts down the RPC client, closing its underlying publisher and consumer. +// It ensures that all pending requests are cleaned up by closing their respective +// channels. This method is safe to call multiple times. +func (a *amqpRpcClient) Close(ctx context.Context) error { + var err error + a.closer.Do(func() { + a.mu.Lock() + defer a.mu.Unlock() + a.closed = true + var err1 error + if err1 = a.publisher.Close(ctx); err1 != nil { + Warn("failed to close publisher", "error", err1) + } + var err2 error + if err2 = a.consumer.Close(ctx); err2 != nil { + Warn("failed to close consumer", "error", err2) + } + err = errors.Join(err1, err2) + for k, req := range a.pendingRequests { + close(req.ch) + delete(a.pendingRequests, k) + } + close(a.done) + }) + return err +} + +func (a *amqpRpcClient) Message(body []byte) *amqp.Message { + return amqp.NewMessage(body) +} + +// Publish sends an RPC request message and returns a channel that will receive the reply. +// It first checks if the client is closed. If not, it generates a correlation ID, +// post-processes the message using the configured `RequestPostProcessor`, +// and then publishes the message. If the message is accepted by RabbitMQ, +// an `outstandingRequest` is created and stored, and a channel is returned +// for the reply. The channel will be closed if the request times out or the +// client is closed before a reply is received. +func (a *amqpRpcClient) Publish(ctx context.Context, message *amqp.Message) (<-chan *amqp.Message, error) { + if a.isClosed() { + return nil, fmt.Errorf("rpc client is closed") + } + replyTo, err := a.replyToQueue.toAddress() + if err != nil { + return nil, fmt.Errorf("failed to set reply-to address: %w", err) + } + if message.Properties == nil { + message.Properties = &amqp.MessageProperties{} + } + message.Properties.ReplyTo = &replyTo + correlationID := a.correlationIdSupplier.Get() + m := a.requestPostProcessor(message, correlationID) + pr, err := a.publisher.Publish(ctx, m) + if err != nil { + return nil, fmt.Errorf("failed to publish request: %w", err) + } + + switch pr.Outcome.(type) { + case *StateAccepted: + Debug("RabbitMQ accepted the request", "correlationID", correlationID) + default: + return nil, fmt.Errorf("RabbitMQ did not accept the request: %s", pr.Outcome) + } + + ch := make(chan *amqp.Message, 1) + a.mu.Lock() + a.pendingRequests[correlationID] = &outstandingRequest{ + sentAt: time.Now(), + ch: ch, + } + a.mu.Unlock() + return ch, nil +} + +func (a *amqpRpcClient) isClosed() bool { + a.mu.Lock() + defer a.mu.Unlock() + return a.closed +} + +// requestTimeoutTask is a goroutine that periodically checks for timed-out RPC requests. +// It runs a ticker and, when triggered, iterates through the pending requests. +// If a request's `sentAt` timestamp is older than the `requestTimeout`, +// its channel is closed, and the request is removed from `pendingRequests`. +// The goroutine exits when the `done` channel is closed, typically when the client is closed. +func (a *amqpRpcClient) requestTimeoutTask() { + t := time.NewTicker(a.requestTimeout) + defer t.Stop() + for { + select { + case <-t.C: + limit := time.Now().Add(-a.requestTimeout) + a.mu.Lock() + maps.DeleteFunc(a.pendingRequests, func(k any, request *outstandingRequest) bool { + if request.sentAt.Before(limit) { + close(request.ch) + Warn("request timed out", "correlationID", k) + return true + } + return false + }) + a.mu.Unlock() + case <-a.done: + return + } + } +} + +// messageReceivedHandler is a goroutine that continuously receives messages from the reply queue. +// It extracts the correlation ID from each received message and attempts to match it with +// an `outstandingRequest`. If a match is found, the reply message is sent to the +// corresponding request's channel, and the request is removed from `pendingRequests`. +// If no match is found, the message is requeued. The goroutine exits when the `done` +// channel is closed, typically when the client is closed. +func (a *amqpRpcClient) messageReceivedHandler() { + for { + select { + case <-a.done: + Debug("rpc client message handler exited") + return + default: + } + + dc, err := a.consumer.Receive(context.Background()) + if err != nil { + Warn("failed to receive message", "error", err) + continue + } + + m := dc.Message() + correlationID := a.correlationIdExtractor(m) + a.mu.Lock() + pendingRequest, exists := a.pendingRequests[correlationID] + if exists { + delete(a.pendingRequests, correlationID) + a.mu.Unlock() + pendingRequest.ch <- m + close(pendingRequest.ch) + err := dc.Accept(context.Background()) + if err != nil { + Warn("error accepting reply", "error", err) + } + } else { + a.mu.Unlock() + Warn("received reply for unknown correlation ID", "correlationID", correlationID) + err := dc.Requeue(context.Background()) + if err != nil { + Warn("error requeuing reply", "error", err) + } + } + } +} diff --git a/pkg/rabbitmqamqp/rpc_client_test.go b/pkg/rabbitmqamqp/rpc_client_test.go new file mode 100644 index 0000000..59a9092 --- /dev/null +++ b/pkg/rabbitmqamqp/rpc_client_test.go @@ -0,0 +1,152 @@ +package rabbitmqamqp + +import ( + "context" + "fmt" + "time" + + "github.com/Azure/go-amqp" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("RpcClient", func() { + var ( + conn *AmqpConnection + queueName string + consumer *Consumer + publisher *Publisher + ) + + var pongRpcServer = func(ctx context.Context, publisher *Publisher, consumer *Consumer) { + for { + select { + case <-ctx.Done(): + return + default: + // Receive a message from the server consumer + receivedMessage, err := consumer.Receive(ctx) + if err != nil { + // Exit if we can't receive messages (e.g., + // consumer is closed) + GinkgoWriter.Printf("Error receiving message: %v\n", err) + return + } + + msg := receivedMessage.Message() + if msg == nil { + GinkgoWriter.Printf("Received nil message\n") + continue + } + + // Create response with "Pong: " prefix + responseData := "Pong: " + string(msg.GetData()) + replyMessage := amqp.NewMessage([]byte(responseData)) + + // Copy correlation ID and reply-to from request + if msg.Properties != nil { + if replyMessage.Properties == nil { + replyMessage.Properties = &amqp.MessageProperties{} + } + replyMessage.Properties.CorrelationID = + msg.Properties.MessageID + } + + // Send reply to the specified reply-to address + if msg.Properties != nil && msg.Properties.ReplyTo != nil { + replyMessage.Properties.To = msg.Properties.ReplyTo + } + + copyApplicationProperties(msg, replyMessage) + + publisher.Publish(ctx, replyMessage) + } + } + } + + BeforeEach(func() { + queueName = generateNameWithDateTime(CurrentSpecReport().LeafNodeText) + var err error + conn, err = declareQueueAndConnection(queueName) + Expect(err).ToNot(HaveOccurred()) + consumer, err = conn.NewConsumer(context.Background(), queueName, nil) + Expect(err).ToNot(HaveOccurred()) + publisher, err = conn.NewPublisher(context.Background(), nil, nil) + Expect(err).ToNot(HaveOccurred()) + }) + + AfterEach(func() { + _ = consumer.Close(context.Background()) + _ = publisher.Close(context.Background()) + _ = conn.Close(context.Background()) + }) + + It("should send a request and receive replies", func(ctx SpecContext) { + // Server goroutine to handle incoming requests + go pongRpcServer(ctx, publisher, consumer) + + client, err := conn.NewRpcClient(ctx, &RpcClientOptions{ + RequestQueueName: queueName, + }) + Ω(err).ShouldNot(HaveOccurred()) + DeferCleanup(func(ctx SpecContext) { + // Closing twice in case the test fails and the 'happy path' close is not called + _ = client.Close(ctx) + }) + + for i := 0; i < 10; i++ { + m := client.Message([]byte(fmt.Sprintf("Message %d", i))) + replyCh, err := client.Publish(ctx, m) + Ω(err).ShouldNot(HaveOccurred()) + actualReply := &amqp.Message{} + Eventually(replyCh). + Within(time.Second). + WithPolling(time.Millisecond * 100). + Should(Receive(&actualReply)) + Expect(actualReply.GetData()).To(BeEquivalentTo(fmt.Sprintf("Pong: Message %d", i))) + } + Ω(client.Close(ctx)).Should(Succeed()) + }) + + It("uses a custom correlation id extractor and post processor", func(ctx SpecContext) { + go pongRpcServer(ctx, publisher, consumer) + client, err := conn.NewRpcClient(ctx, &RpcClientOptions{ + RequestQueueName: queueName, + CorrelationIdExtractor: func(message *amqp.Message) any { + return message.ApplicationProperties["correlationId"] + }, + RequestPostProcessor: func(request *amqp.Message, correlationID any) *amqp.Message { + if request.ApplicationProperties == nil { + request.ApplicationProperties = make(map[string]any) + } + request.ApplicationProperties["correlationId"] = correlationID + if request.Properties == nil { + request.Properties = &amqp.MessageProperties{} + } + request.Properties.MessageID = correlationID + + return request + }, + }) + Ω(err).ShouldNot(HaveOccurred()) + DeferCleanup(func(ctx SpecContext) { + // Closing twice in case the test fails and the 'happy path' close is not called + _ = client.Close(ctx) + }) + + request := client.Message([]byte("Using a custom correlation id extractor and post processor")) + request.ApplicationProperties = map[string]any{"this-property": "should-be-preserved"} + replyCh, err := client.Publish(ctx, request) + Ω(err).ShouldNot(HaveOccurred()) + + actualReply := &amqp.Message{} + Eventually(replyCh). + Within(time.Second). + WithPolling(time.Millisecond * 100). + Should(Receive(&actualReply)) + Expect(actualReply.GetData()).To(BeEquivalentTo("Pong: Using a custom correlation id extractor and post processor")) + Expect(actualReply.ApplicationProperties).To(HaveKey("correlationId")) + Expect(actualReply.ApplicationProperties).To(HaveKeyWithValue("this-property", "should-be-preserved")) + Ω(client.Close(ctx)).Should(Succeed()) + }) +}) diff --git a/pkg/rabbitmqamqp/rpc_e2e_test.go b/pkg/rabbitmqamqp/rpc_e2e_test.go new file mode 100644 index 0000000..c2ecf19 --- /dev/null +++ b/pkg/rabbitmqamqp/rpc_e2e_test.go @@ -0,0 +1,214 @@ +package rabbitmqamqp_test + +import ( + "context" + "fmt" + "sync" + + "github.com/Azure/go-amqp" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "github.com/rabbitmq/rabbitmq-amqp-go-client/pkg/rabbitmqamqp" +) + +// fibonacci calculates and returns the Fibonacci number of x. +// For x <= 1, it returns x. For x > 1, it calculates iteratively to avoid +// stack overflow for large values. +func fibonacci(x int) int { + if x <= 1 { + return x + } + + a, b := 0, 1 + for i := 2; i <= x; i++ { + a, b = b, a+b + } + return b +} + +var _ = Describe("RPC E2E", Label("e2e"), func() { + var ( + clientConn *rabbitmqamqp.AmqpConnection + serverConn *rabbitmqamqp.AmqpConnection + rpcClient rabbitmqamqp.RpcClient + rpcServer rabbitmqamqp.RpcServer + rpcQueueName string + ) + + const ( + rpcQueueNamePrefix = "rpc-e2e-test" + ) + + BeforeEach(func(ctx SpecContext) { + var err error + clientConn, err = rabbitmqamqp.Dial(ctx, "amqp://localhost:5672", &rabbitmqamqp.AmqpConnOptions{ + Properties: map[string]any{"connection_name": "rpc client e2e test"}, + }) + Ω(err).ShouldNot(HaveOccurred()) + serverConn, err = rabbitmqamqp.Dial(ctx, "amqp://localhost:5672", &rabbitmqamqp.AmqpConnOptions{ + Properties: map[string]any{"connection_name": "rpc server e2e test"}, + }) + Ω(err).ShouldNot(HaveOccurred()) + + rpcQueueName = fmt.Sprintf("%s_%s_%d", rpcQueueNamePrefix, CurrentSpecReport().LeafNodeText, GinkgoParallelProcess()) + _, err = serverConn.Management().DeclareQueue(ctx, &rabbitmqamqp.ClassicQueueSpecification{ + Name: rpcQueueName, + IsAutoDelete: true, + }) + Ω(err).ShouldNot(HaveOccurred()) + }) + + AfterEach(func(ctx SpecContext) { + _ = rpcClient.Close(ctx) + _ = rpcServer.Close(ctx) + Ω(clientConn.Close(ctx)).Should(Succeed()) + Ω(serverConn.Close(ctx)).Should(Succeed()) + }) + + It("should work with minimal options", func(ctx SpecContext) { + m := sync.Mutex{} + messagesReceivedByServer := 0 + var err error + rpcClient, err = clientConn.NewRpcClient(ctx, &rabbitmqamqp.RpcClientOptions{ + RequestQueueName: rpcQueueName, + }) + Ω(err).ShouldNot(HaveOccurred()) + rpcServer, err = serverConn.NewRpcServer(ctx, rabbitmqamqp.RpcServerOptions{ + RequestQueue: rpcQueueName, + Handler: func(ctx context.Context, request *amqp.Message) (*amqp.Message, error) { + m.Lock() + messagesReceivedByServer += 1 + m.Unlock() + reply, err := rabbitmqamqp.NewMessageWithAddress([]byte{}, &rabbitmqamqp.QueueAddress{ + Queue: *request.Properties.ReplyTo, + }) + if err != nil { + panic(err) + } + x := request.ApplicationProperties["x"].(int64) // int64 because codec encodes int as int64 + reply.ApplicationProperties = map[string]any{"fib": fibonacci(int(x)), "x": x} + return reply, nil + }, + }) + Ω(err).ShouldNot(HaveOccurred()) + + By("sending and waiting sequentially") + var expectedFibonacciNumbers [10]int = [10]int{1, 1, 2, 3, 5, 8, 13, 21, 34, 55} + for i := 1; i <= 10; i++ { + msg := rpcClient.Message([]byte{}) + msg.ApplicationProperties = map[string]any{"x": i} + pendingRequestCh, err := rpcClient.Publish(ctx, msg) + Ω(err).ShouldNot(HaveOccurred()) + select { + case m := <-pendingRequestCh: + Ω(m.ApplicationProperties["fib"]).Should(BeEquivalentTo(expectedFibonacciNumbers[i-1])) + Ω(m.ApplicationProperties["x"]).Should(BeEquivalentTo(i)) + case <-ctx.Done(): + Fail(ctx.Err().Error()) + } + } + Expect(messagesReceivedByServer).To(Equal(10)) + + By("sending a batch and receiving replies") + responseChans := make([]<-chan *amqp.Message, 0, 10) + for i := 1; i <= 10; i++ { + msg := rpcClient.Message([]byte{}) + msg.ApplicationProperties = map[string]any{"x": i} + ch, err := rpcClient.Publish(ctx, msg) + Ω(err).ShouldNot(HaveOccurred()) + responseChans = append(responseChans, ch) + } + + for i, ch := range responseChans { + select { + case m := <-ch: + Ω(m.ApplicationProperties["fib"]).Should(BeEquivalentTo(expectedFibonacciNumbers[i])) + Ω(m.ApplicationProperties["x"]).Should(BeEquivalentTo(i + 1)) + case <-ctx.Done(): + Fail(ctx.Err().Error()) + } + } + Expect(messagesReceivedByServer).To(Equal(20)) + + Expect(rpcClient.Close(ctx)).To(Succeed()) + Expect(rpcServer.Close(ctx)).To(Succeed()) + }) +}) + +func ExampleRpcClient() { + // open a connection + conn, err := rabbitmqamqp.Dial(context.TODO(), "amqp://localhost:5672", &rabbitmqamqp.AmqpConnOptions{ + Properties: map[string]any{"connection_name": "example rpc client"}, + }) + if err != nil { + panic(err) + } + defer conn.Close(context.TODO()) + + // Create RPC client options + // RequestQueueName is mandatory. The queue must exist. + options := rabbitmqamqp.RpcClientOptions{ + RequestQueueName: "rpc-queue", + } + // Create a new RPC client + rpcClient, err := conn.NewRpcClient(context.TODO(), &options) + if err != nil { + panic(err) + } + defer rpcClient.Close(context.TODO()) + + // Create an AMQP message with some initial data + msg := rpcClient.Message([]byte("hello world")) + // Add some application properties to the message + msg.ApplicationProperties = map[string]any{"example": "rpc"} + + // Send the message to the server + pendingRequestCh, err := rpcClient.Publish(context.TODO(), msg) + if err != nil { + panic(err) + } + + // Wait for the reply from the server + replyFromServer := <-pendingRequestCh + // Print the reply from the server + // This example assumes that the server is an "echo" server, that just returns the message it received. + fmt.Printf("application property 'example': %s\n", replyFromServer.ApplicationProperties["example"]) + fmt.Printf("reply correlation ID: %s\n", replyFromServer.Properties.CorrelationID) +} + +type fooCorrelationIdSupplier struct { + count int +} + +func (c *fooCorrelationIdSupplier) Get() any { + c.count++ + return fmt.Sprintf("foo-%d", c.count) +} + +func ExampleRpcClient_customCorrelationID() { + // type fooCorrelationIdSupplier struct { + // count int + // } + // + // func (c *fooCorrelationIdSupplier) Get() any { + // c.count++ + // return fmt.Sprintf("foo-%d", c.count) + // } + + // Connection setup + conn, _ := rabbitmqamqp.Dial(context.TODO(), "amqp://", nil) + defer conn.Close(context.TODO()) + + // Create RPC client options + options := rabbitmqamqp.RpcClientOptions{ + RequestQueueName: "rpc-queue", // the queue must exist + CorrelationIdSupplier: &fooCorrelationIdSupplier{}, + } + + // Create a new RPC client + rpcClient, _ := conn.NewRpcClient(context.TODO(), &options) + pendingRequestCh, _ := rpcClient.Publish(context.TODO(), rpcClient.Message([]byte("hello world"))) + replyFromServer := <-pendingRequestCh + fmt.Printf("reply correlation ID: %s\n", replyFromServer.Properties.CorrelationID) + // Should print: foo-1 +} diff --git a/pkg/rabbitmqamqp/rpc_server.go b/pkg/rabbitmqamqp/rpc_server.go new file mode 100644 index 0000000..1f593b8 --- /dev/null +++ b/pkg/rabbitmqamqp/rpc_server.go @@ -0,0 +1,270 @@ +package rabbitmqamqp + +import ( + "context" + "fmt" + "sync" + "time" + + amqp "github.com/Azure/go-amqp" +) + +// RpcServerHandler is a function that processes a request message and returns a response message. +// If the server wants to send a response to the client, it must return a response message. +// If the function returns nil, the server will not send a response. +// If the server does not send a response message, this high level RPC server doesn't make much sense, +// and it is better to use a normal AMQP 1.0 consumer. +// +// The server handler blocks until this function returns. It is highly recommended to use functions that process and return quickly. +// +// Example: +// +// func(ctx context.Context, request *amqp.Message) (*amqp.Message, error) { +// return amqp.NewMessage([]byte(fmt.Sprintf("Pong: %s", request.GetData()))), nil +// } +type RpcServerHandler func(ctx context.Context, request *amqp.Message) (*amqp.Message, error) + +var noOpHandler RpcServerHandler = func(_ context.Context, _ *amqp.Message) (*amqp.Message, error) { + return nil, nil +} + +// CorrelationIdExtractor defines the signature for a function that extracts the correlation ID +// from an AMQP message. Then returned value must be a valid AMQP type that can be binary encoded. +type CorrelationIdExtractor func(message *amqp.Message) any + +var defaultCorrelationIdExtractor CorrelationIdExtractor = func(message *amqp.Message) any { + if message.Properties == nil { + return nil + } + return message.Properties.MessageID +} + +// ReplyPostProcessor is a function that is called after the request handler has processed the request. +// It can be used to modify the reply message before it is sent. +type ReplyPostProcessor func(reply *amqp.Message, correlationID any) *amqp.Message + +var defaultReplyPostProcessor ReplyPostProcessor = func(reply *amqp.Message, correlationID any) *amqp.Message { + if reply == nil { + return nil + } + if reply.Properties == nil { + reply.Properties = &amqp.MessageProperties{} + } + reply.Properties.CorrelationID = correlationID + return reply +} + +// RpcServer is Remote Procedure Call server that receives a message, process them, +// and sends a response. +type RpcServer interface { + // Close the RPC server and its underlying resources. + Close(context.Context) error + // Pause the server to stop receiving messages. + Pause() + // Unpause requests to receive messages again. + Unpause() error +} + +type RpcServerOptions struct { + // RequestQueue is the name of the queue to subscribe to. This queue must be pre-declared. + // The RPC server does not declare the queue, it is the responsibility of the caller to declare the queue. + // + // Mandatory. + RequestQueue string + // Handler is a function to process the request message. If the server wants to send a response to + // the client, it must return a response message. If the function returns nil, the server will not send a response. + // + // It is encouraged to initialise the response message properties in the handler. If the handler returns a non-nil + // error, the server will discard the request message and log an error. + // + // The server handler blocks until this function returns. It is highly recommended to functions that process and return quickly. + // If you need to perform a long running operation, it's advisable to dispatch the operation to another queue. + // + // Example: + // func(ctx context.Context, request *amqp.Message) (*amqp.Message, error) { + // return amqp.NewMessage([]byte(fmt.Sprintf("Pong: %s", request.GetData()))), nil + // } + // + // Mandatory. + Handler RpcServerHandler + // CorrectionIdExtractor is a function that extracts a correction ID from the request message. + // The returned value should be an AMQP type that can be binary encoded. + // + // This field is optional. If not provided, the server will use the `MessageID` as the correlation ID. + // + // Example: + // func(message *amqp.Message) any { + // return message.Properties.MessageID + // } + // + // The default correlation ID extractor also handles nil cases. + // + // Optional. + CorrelationIdExtractor CorrelationIdExtractor + // PostProcessor is a function that receives the reply message and the extracted correlation ID, just before the reply is sent. + // It can be used to modify the reply message before it is sent. + // + // The post processor must set the correlation ID in the reply message properties. + // + // This field is optional. If not provided, the server will set the correlation ID in the reply message properties, using + // the correlation ID extracted from the CorrelationIdExtractor. + // + // Example: + // func(reply *amqp.Message, correlationID any) *amqp.Message { + // reply.Properties.CorrelationID = correlationID + // return reply + // } + // + // The default post processor also handles nil cases. + // + // Optional. + ReplyPostProcessor ReplyPostProcessor +} + +func (r *RpcServerOptions) validate() error { + if r.RequestQueue == "" { + return fmt.Errorf("requestQueue is mandatory") + } + return nil +} + +type amqpRpcServer struct { + // TODO: handle state changes for reconnections + mu sync.Mutex + requestHandler RpcServerHandler + requestQueue string + publisher *Publisher + consumer *Consumer + closer sync.Once + closed bool + correlationIdExtractor CorrelationIdExtractor + replyPostProcessor ReplyPostProcessor +} + +// Close closes the RPC server and its underlying AMQP resources. It ensures that these resources +// are closed gracefully and only once, even if Close is called multiple times. +// The provided context (ctx) controls the timeout for the close operation, ensuring the operation +// does not exceed the context's deadline. +func (a *amqpRpcServer) Close(ctx context.Context) error { + // TODO: wait for unsettled messages + a.closer.Do(func() { + a.mu.Lock() + defer a.mu.Unlock() + a.closed = true + // TODO: set a context timeout for the publisher and consumer close operations + if a.publisher != nil { + err := a.publisher.Close(ctx) + if err != nil { + Error("Failed to close publisher", "error", err) + } + } + if a.consumer != nil { + err := a.consumer.Close(ctx) + if err != nil { + Error("Failed to close consumer", "error", err) + } + } + }) + return nil +} + +func (a *amqpRpcServer) Pause() { + err := a.consumer.pause(context.Background()) + if err != nil { + Warn("Did not pause consumer", "error", err) + } +} + +func (a *amqpRpcServer) Unpause() error { + a.mu.Lock() + if a.closed { + a.mu.Unlock() + return nil + } + a.mu.Unlock() + + err := a.consumer.unpause(1) + if err != nil { + return fmt.Errorf("error unpausing RPC server: %w", err) + } + return nil +} + +func (a *amqpRpcServer) handle() { + /* + The RPC server has the following behavior: + - when receiving a message request: + - it calls the processing logic (handler) + - it extracts the correlation ID + - it calls a reply post-processor if defined + - it sends the reply message + - if all these operations succeed, the server accepts the request message (settles it with the ACCEPTED outcome) + - if any of these operations throws an exception, the server discards the request message (the message is + removed from the request queue and is dead-lettered if configured) + */ + for { + if a.isClosed() { + Debug("RPC server is closed. Stopping the handler") + return + } + + err := a.issueCredits(1) + if err != nil { + Warn("Failed to request credits", "error", err) + continue + } + + request, err := a.consumer.Receive(context.Background()) + if err != nil { + Debug("Receive request returned error. This may be expected if the server is closing", "error", err) + continue + } + // TODO: add a configurable timeout for the request handling + reply, err := a.requestHandler(context.Background(), request.message) + if err != nil { + Error("Request handler returned error. Discarding request", "error", err) + request.Discard(context.Background(), nil) + continue + } + + if reply != nil && request.message.Properties != nil && request.message.Properties.ReplyTo != nil { + setToProperty(reply, request.message.Properties.ReplyTo) + } + + correlationID := a.correlationIdExtractor(request.message) + reply = a.replyPostProcessor(reply, correlationID) + if reply != nil { + err = callAndMaybeRetry(func() error { + r, err := a.publisher.Publish(context.Background(), reply) + if err != nil { + return err + } + switch r.Outcome.(type) { + case *StateAccepted: + return nil + } + return fmt.Errorf("reply message not accepted: %s", r.Outcome) + }, []time.Duration{time.Second, 3 * time.Second, 5 * time.Second, 10 * time.Second}) + if err != nil { + Error("Failed to publish reply", "error", err, "correlationId", reply.Properties.CorrelationID) + request.Discard(context.Background(), nil) + continue + } + } + + err = request.Accept(context.Background()) + if err != nil { + Error("Failed to accept request", "error", err, "messageId", request.message.Properties.MessageID) + } + } +} + +func (a *amqpRpcServer) isClosed() bool { + a.mu.Lock() + defer a.mu.Unlock() + return a.closed +} + +func (a *amqpRpcServer) issueCredits(credits uint32) error { + return a.consumer.issueCredits(credits) +} diff --git a/pkg/rabbitmqamqp/rpc_server_test.go b/pkg/rabbitmqamqp/rpc_server_test.go new file mode 100644 index 0000000..89fe475 --- /dev/null +++ b/pkg/rabbitmqamqp/rpc_server_test.go @@ -0,0 +1,180 @@ +package rabbitmqamqp + +import ( + "context" + "fmt" + "log/slog" + "time" + + amqp "github.com/Azure/go-amqp" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "github.com/onsi/gomega/gbytes" +) + +var _ = Describe("RpcServer", func() { + var ( + conn *AmqpConnection + requestQueue string + ) + + BeforeEach(func() { + requestQueue = generateNameWithDateTime(CurrentSpecReport().LeafNodeText) + var err error + conn, err = declareQueueAndConnection(requestQueue) + Expect(err).ToNot(HaveOccurred()) + }) + + AfterEach(func(ctx SpecContext) { + conn.Close(ctx) + }) + + It("process incoming message requests", func(ctx SpecContext) { + // setup + processedMessage := make(chan string, 5) + + replyQueue, err := conn.management.DeclareQueue( + context.Background(), + &ClassicQueueSpecification{Name: fmt.Sprintf("reply-queue_%s", CurrentSpecReport().LeafNodeText), IsExclusive: true}, + ) + Expect(err).ToNot(HaveOccurred()) + + replyConsumer, err := conn.NewConsumer(context.Background(), replyQueue.Name(), nil) + Expect(err).ToNot(HaveOccurred()) + requestPublisher, err := conn.NewPublisher(context.Background(), &QueueAddress{Queue: requestQueue}, nil) + Expect(err).ToNot(HaveOccurred()) + DeferCleanup(func(ctx SpecContext) { + replyConsumer.Close(ctx) + requestPublisher.Close(ctx) + }) + + server, err := conn.NewRpcServer(context.Background(), RpcServerOptions{ + RequestQueue: requestQueue, + Handler: func(ctx context.Context, request *amqp.Message) (*amqp.Message, error) { + if request.Properties == nil { + return nil, fmt.Errorf("request properties are nil") + } + messageID, ok := request.Properties.MessageID.(string) + if !ok { + return nil, fmt.Errorf("correlation ID is not a string") + } + processedMessage <- messageID + reply := amqp.NewMessage([]byte("reply")) + return reply, nil + }, + }) + Expect(err).ToNot(HaveOccurred()) + DeferCleanup(func(ctx SpecContext) { + server.Close(ctx) + }, NodeTimeout(time.Second*10)) + + // act + message := amqp.NewMessage([]byte("message 1")) + q, err := (&QueueAddress{Queue: replyQueue.Name()}).toAddress() + Expect(err).ToNot(HaveOccurred()) + message.Properties = &amqp.MessageProperties{ + MessageID: "1", + ReplyTo: &q, + } + res, err := requestPublisher.Publish(ctx, message) + Expect(err).ToNot(HaveOccurred()) + Expect(res.Outcome).To(BeAssignableToTypeOf(&StateAccepted{}), "expected rabbit to confirm the message") + + // assert + Eventually(processedMessage).Within(time.Second).Should(Receive(Equal("1"))) + serverReply, err := replyConsumer.Receive(ctx) + Expect(err).ToNot(HaveOccurred()) + m := serverReply.Message() + Expect(m).ToNot(BeNil()) + Expect(m.GetData()).To(BeEquivalentTo("reply")) + Expect(m.Properties.CorrelationID).To(BeEquivalentTo("1")) + }, SpecTimeout(time.Second*10)) + + It("stops the handler when the RPC server closes", func(ctx SpecContext) { + // setup + server, err := conn.NewRpcServer(context.Background(), RpcServerOptions{ + RequestQueue: requestQueue, + }) + Expect(err).ToNot(HaveOccurred()) + DeferCleanup(func(ctx SpecContext) { + server.Close(ctx) + }, NodeTimeout(time.Second*10)) + + buf := gbytes.NewBuffer() + SetSlogHandler(NewGinkgoHandler(slog.LevelDebug, buf)) + time.Sleep(time.Second) // ugly but necessary to wait for the server to call Receive() and block + + // act + server.Close(ctx) + + // assert + Eventually(buf).Within(time.Second).Should(gbytes.Say("Receive request returned error. This may be expected if the server is closing")) + Eventually(buf).Within(time.Second).Should(gbytes.Say("RPC server is closed. Stopping the handler")) + }, SpecTimeout(time.Second*10)) + + It("uses a custom correlation id extractor and post processor", func(ctx SpecContext) { + // setup + replyQueue, err := conn.management.DeclareQueue( + context.Background(), + &ClassicQueueSpecification{Name: fmt.Sprintf("reply-queue_%s", CurrentSpecReport().LeafNodeText), IsExclusive: true}, + ) + Expect(err).ToNot(HaveOccurred()) + + replyConsumer, err := conn.NewConsumer(context.Background(), replyQueue.Name(), nil) + Expect(err).ToNot(HaveOccurred()) + requestPublisher, err := conn.NewPublisher(context.Background(), &QueueAddress{Queue: requestQueue}, nil) + Expect(err).ToNot(HaveOccurred()) + DeferCleanup(func(ctx SpecContext) { + replyConsumer.Close(ctx) + requestPublisher.Close(ctx) + }) + + correlationIdExtractor := func(message *amqp.Message) any { + return message.ApplicationProperties["message-id"] + } + postProcessor := func(reply *amqp.Message, correlationID any) *amqp.Message { + reply.Properties.CorrelationID = correlationID + reply.ApplicationProperties["test"] = "success" + return reply + } + server, err := conn.NewRpcServer(context.Background(), RpcServerOptions{ + RequestQueue: requestQueue, + Handler: func(ctx context.Context, request *amqp.Message) (*amqp.Message, error) { + m := amqp.NewMessage(request.GetData()) + m.Properties = &amqp.MessageProperties{} + m.ApplicationProperties = make(map[string]any) + return m, nil + }, + CorrelationIdExtractor: correlationIdExtractor, + ReplyPostProcessor: postProcessor, + }) + Expect(err).ToNot(HaveOccurred()) + DeferCleanup(func(ctx SpecContext) { + server.Close(ctx) + }, NodeTimeout(time.Second*10)) + + // act + message := amqp.NewMessage([]byte("message with custom correlation id extractor and custom post processor")) + q, err := (&QueueAddress{Queue: replyQueue.Name()}).toAddress() + Expect(err).ToNot(HaveOccurred()) + message.Properties = &amqp.MessageProperties{ + MessageID: 123, + ReplyTo: &q, + } + message.ApplicationProperties = map[string]any{ + "message-id": "my-message-id", + } + res, err := requestPublisher.Publish(ctx, message) + Expect(err).ToNot(HaveOccurred()) + Expect(res.Outcome).To(BeAssignableToTypeOf(&StateAccepted{}), "expected rabbit to confirm the message") + + // assert + serverReply, err := replyConsumer.Receive(ctx) + Expect(err).ToNot(HaveOccurred()) + m := serverReply.Message() + Expect(m).ToNot(BeNil()) + Expect(m.GetData()).To(BeEquivalentTo("message with custom correlation id extractor and custom post processor")) + Expect(m.Properties.CorrelationID).To(BeEquivalentTo("my-message-id")) + Expect(m.ApplicationProperties["test"]).To(BeEquivalentTo("success")) + }, SpecTimeout(time.Second*10)) +}) diff --git a/pkg/rabbitmqamqp/test_utils.go b/pkg/rabbitmqamqp/test_utils.go index b4f7a94..dfbd66b 100644 --- a/pkg/rabbitmqamqp/test_utils.go +++ b/pkg/rabbitmqamqp/test_utils.go @@ -1,26 +1,23 @@ package rabbitmqamqp import ( + "context" "fmt" + "io" + "log/slog" + "maps" + "os" "strconv" "time" + + "github.com/Azure/go-amqp" ) func generateNameWithDateTime(name string) string { return fmt.Sprintf("%s_%s", name, strconv.FormatInt(time.Now().Unix(), 10)) } -// Helper function to create string pointers -func stringPtr(s string) *string { - return &s -} - -func uint32Ptr(i uint32) *uint32 { - return &i -} - // create a static date time string for testing - func createDateTime() time.Time { layout := time.RFC3339 value := "2006-01-02T15:04:05Z" @@ -31,7 +28,55 @@ func createDateTime() time.Time { return t } -// convert time to pointer -func timePtr(t time.Time) *time.Time { - return &t +// ptr returns a pointer to the given value of type T +func ptr[T any](v T) *T { + return &v +} + +type GinkgoLogHandler struct { + slog.Handler + w io.Writer +} + +func (h *GinkgoLogHandler) Handle(_ context.Context, r slog.Record) error { + _, err := h.w.Write([]byte(r.Message + "\n")) + return err +} + +func NewGinkgoHandler(level slog.Level, writer io.Writer) slog.Handler { + handlerOptions := &slog.HandlerOptions{ + Level: level, + AddSource: true, + } + + return &GinkgoLogHandler{ + Handler: slog.NewJSONHandler(os.Stdout, handlerOptions), + w: writer, + } +} + +func declareQueueAndConnection(name string) (*AmqpConnection, error) { + connection, err := Dial(context.Background(), "amqp://", nil) + if err != nil { + return nil, err + } + _, err = connection.Management().DeclareQueue(context.Background(), &ClassicQueueSpecification{Name: name, IsAutoDelete: true}) + if err != nil { + return nil, err + } + return connection, nil +} + +func copyApplicationProperties(from *amqp.Message, to *amqp.Message) { + if to == nil || from == nil { + return + } + if from.ApplicationProperties == nil { + to.ApplicationProperties = nil + return + } + if to.ApplicationProperties == nil { + to.ApplicationProperties = make(map[string]any) + } + maps.Copy(to.ApplicationProperties, from.ApplicationProperties) }