diff --git a/go.mod b/go.mod
index 0d8c70ff..0a3eabab 100644
--- a/go.mod
+++ b/go.mod
@@ -32,5 +32,6 @@ require (
 	golang.org/x/net v0.46.0 // indirect
 	golang.org/x/sys v0.37.0 // indirect
 	golang.org/x/tools v0.38.0 // indirect
+	google.golang.org/protobuf v1.36.7 // indirect
 	gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
 )
diff --git a/pkg/ha/ha_consumer.go b/pkg/ha/ha_consumer.go
index f9077d0a..fb56b2d3 100644
--- a/pkg/ha/ha_consumer.go
+++ b/pkg/ha/ha_consumer.go
@@ -156,6 +156,8 @@ func (c *ReliableConsumer) GetInfo() string {
 	return c.getInfo()
 }
 
+// Deprecated: see consumer.GetLastStoredOffset()
+// use QueryOffset() instead
 func (c *ReliableConsumer) GetLastStoredOffset() int64 {
 	c.mutexConnection.Lock()
 	defer c.mutexConnection.Unlock()
@@ -163,13 +165,21 @@ func (c *ReliableConsumer) GetLastStoredOffset() int64 {
 	return c.consumer.GetLastStoredOffset()
 }
 
-func (c *ReliableConsumer) StoreOffset() error {
+// QueryOffset returns the last stored offset for this consumer given its name and stream
+func (c *ReliableConsumer) QueryOffset() (int64, error) {
 	c.mutexConnection.Lock()
 	defer c.mutexConnection.Unlock()
+	return c.consumer.QueryOffset()
+}
 
+// StoreOffset stores the current offset for this consumer given its name and stream
+func (c *ReliableConsumer) StoreOffset() error {
+	c.mutexConnection.Lock()
+	defer c.mutexConnection.Unlock()
 	return c.consumer.StoreOffset()
 }
 
+// StoreCustomOffset stores a custom offset for this consumer given its name and stream
 func (c *ReliableConsumer) StoreCustomOffset(offset int64) error {
 	c.mutexConnection.Lock()
 	defer c.mutexConnection.Unlock()
diff --git a/pkg/ha/ha_consumer_test.go b/pkg/ha/ha_consumer_test.go
index 74ea55c4..48785fd0 100644
--- a/pkg/ha/ha_consumer_test.go
+++ b/pkg/ha/ha_consumer_test.go
@@ -166,9 +166,10 @@ var _ = Describe("Reliable Consumer", func() {
 			SetConsumerName(clientProvidedName).
 			SetClientProvidedName(clientProvidedName),
 			func(ctx ConsumerContext, _ *amqp.Message) {
+				defer GinkgoRecover()
 				// call on every message to test the re-connection.
 				offset := ctx.Consumer.GetOffset()
-				_ = ctx.Consumer.StoreCustomOffset(offset - 1) // commit all except the last one
+				Expect(ctx.Consumer.StoreCustomOffset(offset - 1)).To(BeNil()) // commit all except the last one
 
 				// wait the connection drop to ensure correct offset tracking on re-connection
 				if offset == messageToSend/2 {
@@ -179,36 +180,26 @@ var _ = Describe("Reliable Consumer", func() {
 			Expect(err).NotTo(HaveOccurred())
 			Expect(consumer).NotTo(BeNil())
 
-			connectionToDrop := ""
-			Eventually(func() bool {
-				connections, err := test_helper.Connections("15672")
-				if err != nil {
-					return false
-				}
-				for _, connection := range connections {
-					if connection.ClientProperties.Connection_name == clientProvidedName {
-						connectionToDrop = connection.Name
-						return true
-					}
-				}
-				return false
-			}, time.Second*5).
-				Should(BeTrue())
-
-			Expect(connectionToDrop).NotTo(BeEmpty())
 			// kill the connection
-			errDrop := test_helper.DropConnection(connectionToDrop, "15672")
-			Expect(errDrop).NotTo(HaveOccurred())
 			dropSignal <- struct{}{}
+			Eventually(func() (bool, error) { return test_helper.IsConnectionAlive(clientProvidedName, "15672") }, 10*time.Second).WithPolling(500*time.Millisecond).
+				Should(BeTrue(), "check if the connection is alive")
+			errDrop := test_helper.DropConnectionAndWait(clientProvidedName, "15672", 10*time.Second)
+			Expect(errDrop).NotTo(HaveOccurred())
 			Expect(err).NotTo(HaveOccurred())
-			Eventually(func() int64 { return consumer.GetLastStoredOffset() }, 10*time.Second).
-				Should(Equal(int64(98)), "Offset should be 99")
+
+			Eventually(func() (bool, error) { return test_helper.IsConnectionAlive(clientProvidedName, "15672") }, 20*time.Second).
+				WithPolling(500*time.Millisecond).
+				Should(BeTrue(), "check if the connection is alive")
+
+			Eventually(func() (int64, error) { return consumer.QueryOffset() }, 10*time.Second).WithPolling(500*time.Millisecond).
+				Should(Equal(int64(98)), "Offset should be 98")
 
 			// set a custom offset
-			Expect(consumer.StoreCustomOffset(99)).NotTo(HaveOccurred())
-			Eventually(func() int64 { return consumer.GetLastStoredOffset() }, 1*time.Second).
-				Should(Equal(int64(99)), "Offset should be 99")
+			Expect(consumer.StoreCustomOffset(33)).NotTo(HaveOccurred())
+			Eventually(func() (int64, error) { return consumer.QueryOffset() }, 1*time.Second).
+				Should(Equal(int64(33)), "Offset should be 33 due to custom commit")
 
 			Expect(consumer.Close()).NotTo(HaveOccurred())
 		})
diff --git a/pkg/integration_test/stream_integration_test.go b/pkg/integration_test/stream_integration_test.go
index 96560a93..724903af 100644
--- a/pkg/integration_test/stream_integration_test.go
+++ b/pkg/integration_test/stream_integration_test.go
@@ -20,7 +20,7 @@ var _ = Describe("StreamIntegration", func() {
 	var (
 		addresses = []string{
 			"rabbitmq-stream://guest:guest@localhost:5552/"}
-		streamName           = "test-next"
+		streamName           = fmt.Sprintf("test-next-%d", time.Now().UnixNano())
 		streamEnv            *stream.Environment
 		producer             *stream.Producer
 		totalInitialMessages int
diff --git a/pkg/stream/client.go b/pkg/stream/client.go
index 3f79b7cc..78a5d153 100644
--- a/pkg/stream/client.go
+++ b/pkg/stream/client.go
@@ -677,6 +677,10 @@ func (c *Client) queryPublisherSequence(publisherReference string, stream string
 }
 
 func (c *Client) BrokerLeader(stream string) (*Broker, error) {
+	return c.BrokerLeaderWithResolver(stream, nil)
+}
+
+func (c *Client) BrokerLeaderWithResolver(stream string, resolver *AddressResolver) (*Broker, error) {
 	streamsMetadata := c.metaData(stream)
 	if streamsMetadata == nil {
 		return nil, fmt.Errorf("leader error for stream for stream: %s", stream)
@@ -693,6 +697,13 @@ func (c *Client) BrokerLeader(stream string) (*Broker, error) {
 	streamMetadata.Leader.advPort = streamMetadata.Leader.Port
 	streamMetadata.Leader.advHost = streamMetadata.Leader.Host
 
+	// If AddressResolver is configured, use it directly and skip DNS lookup
+	if resolver != nil {
+		streamMetadata.Leader.Host = resolver.Host
+		streamMetadata.Leader.Port = strconv.Itoa(resolver.Port)
+		return streamMetadata.Leader, nil
+	}
+
 	res := net.Resolver{}
 	// see: https://github.com/rabbitmq/rabbitmq-stream-go-client/pull/317
 	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
@@ -738,12 +749,30 @@ func (c *Client) BrokerForConsumer(stream string) (*Broker, error) {
 	}
 
 	brokers := make([]*Broker, 0, 1+len(streamMetadata.Replicas))
-	brokers = append(brokers, streamMetadata.Leader)
+
+	// Count available replicas
+	availableReplicas := 0
+	for _, replica := range streamMetadata.Replicas {
+		if replica != nil {
+			availableReplicas++
+		}
+	}
+
+	// Only add leader if no replicas are available
+	if availableReplicas == 0 {
+		streamMetadata.Leader.advPort = streamMetadata.Leader.Port
+		streamMetadata.Leader.advHost = streamMetadata.Leader.Host
+		brokers = append(brokers, streamMetadata.Leader)
+	}
+
+	// Add all available replicas
 	for idx, replica := range streamMetadata.Replicas {
 		if replica == nil {
 			logs.LogWarn("Stream %s replica not ready: %d", stream, idx)
 			continue
 		}
+		replica.advPort = replica.Port
+		replica.advHost = replica.Host
 		brokers = append(brokers, replica)
 	}
 
@@ -875,14 +904,18 @@ func (c *Client) declareSubscriber(streamName string,
 		return nil, fmt.Errorf("specify a valid Offset")
 	}
 
-	if options.autoCommitStrategy.flushInterval < 1*time.Second {
+	if (options.autoCommitStrategy != nil) && (options.autoCommitStrategy.flushInterval < 1*time.Second) && options.autocommit {
 		return nil, fmt.Errorf("flush internal must be bigger than one second")
 	}
 
-	if options.autoCommitStrategy.messageCountBeforeStorage < 1 {
+	if (options.autoCommitStrategy != nil) && options.autoCommitStrategy.messageCountBeforeStorage < 1 && options.autocommit {
 		return nil, fmt.Errorf("message count before storage must be bigger than one")
 	}
 
+	if (options.autoCommitStrategy != nil) && options.ConsumerName == "" && options.autocommit {
+		return nil, fmt.Errorf("consumer name must be set when autocommit is enabled")
+	}
+
 	if messagesHandler == nil {
 		return nil, fmt.Errorf("messages Handler must be set")
 	}
diff --git a/pkg/stream/consumer.go b/pkg/stream/consumer.go
index 357975bf..d8770515 100644
--- a/pkg/stream/consumer.go
+++ b/pkg/stream/consumer.go
@@ -101,6 +101,22 @@ func (consumer *Consumer) setPromotedAsActive(promoted bool) {
 	consumer.isPromotedAsActive = promoted
 }
 
+// Deprecated: The method name may be misleading.
+// The method does not return the last offset stored on the server, but the last one stored in memory.
+// The method was added to avoid querying the offset from the server, but it created confusion.
+// Use `QueryOffset` instead:
+//
+//	offset, err := consumer.QueryOffset()
+//	// or:
+//	offset, err := env.QueryOffset(consumerName, streamName)
+//	// check the error
+//	....
+//	SetOffset(stream.OffsetSpecification{}.Offset(offset)).
+//
+// There is an edge case in which multiple clients use the same consumer name,
+// and the last stored offset in memory is not the one the user expects.
+// So, to avoid confusion, it is better to use QueryOffset, which always gets the value from the server.
+
 func (consumer *Consumer) GetLastStoredOffset() int64 {
 	consumer.mutex.Lock()
 	defer consumer.mutex.Unlock()
@@ -312,6 +328,13 @@ func (c *ConsumerOptions) SetClientProvidedName(clientProvidedName string) *Cons
 	return c
 }
 
+func (c *ConsumerOptions) GetClientProvidedName(defaultClientProvidedName string) string {
+	if c == nil || c.ClientProvidedName == "" {
+		return defaultClientProvidedName
+	}
+	return c.ClientProvidedName
+}
+
 func (c *ConsumerOptions) SetFilter(filter *ConsumerFilter) *ConsumerOptions {
 	c.Filter = filter
 	return c
@@ -452,9 +475,12 @@ func (consumer *Consumer) getLastAutoCommitStored() time.Time {
 	return consumer.lastAutoCommitStored
 }
 
+// StoreOffset stores the current offset for this consumer given its name and stream
 func (consumer *Consumer) StoreOffset() error {
 	return consumer.internalStoreOffset()
 }
+
+// StoreCustomOffset stores a custom offset for this consumer given its name and stream
 func (consumer *Consumer) StoreCustomOffset(offset int64) error {
 	consumer.mutex.Lock()
 	defer consumer.mutex.Unlock()
@@ -510,7 +536,11 @@ func (consumer *Consumer) writeConsumeUpdateOffsetToSocket(correlationID uint32,
 	return consumer.options.client.socket.writeAndFlush(b.Bytes())
 }
 
+// QueryOffset returns the last stored offset for this consumer given its name and stream
 func (consumer *Consumer) QueryOffset() (int64, error) {
+	if (consumer.options == nil) || (consumer.options.client == nil) || (consumer.options.ConsumerName == "") || (consumer.options.streamName == "") {
+		return -1, fmt.Errorf("offset query error: consumer not properly initialized")
+	}
 	return consumer.options.client.queryOffset(consumer.options.ConsumerName, consumer.options.streamName)
 }
diff --git a/pkg/stream/consumer_test.go b/pkg/stream/consumer_test.go
index ce77a50a..8ba281fb 100644
--- a/pkg/stream/consumer_test.go
+++ b/pkg/stream/consumer_test.go
@@ -219,7 +219,7 @@ var _ = Describe("Streaming Consumers", func() {
 				SetManualCommit().
 				SetCRCCheck(false))
 		Expect(err).NotTo(HaveOccurred())
-		Eventually(func() int64 { return consumer.GetLastStoredOffset() }, 5*time.Second).Should(Equal(int64(99)),
+		Eventually(func() (int64, error) { return consumer.QueryOffset() }, 5*time.Second).Should(Equal(int64(99)),
 			"Offset should be 99")
 		Expect(consumer.Close()).NotTo(HaveOccurred())
 	})
@@ -236,18 +236,18 @@ var _ = Describe("Streaming Consumers", func() {
 					SetCountBeforeStorage(100).
 					SetFlushInterval(50*time.Second))) // here we set a high value to do not trigger the time
 		Expect(err).NotTo(HaveOccurred())
-		Eventually(func() int64 {
-			return consumer.GetLastStoredOffset()
+		time.Sleep(500 * time.Millisecond)
+		Eventually(func() (int64, error) {
+			v, err := consumer.QueryOffset()
+			// we can ignore the offset not found error here
+			if err != nil {
+				return 0, nil
+			}
+			return v, err
 			// 99 is the offset since it starts from 0
-		}, 5*time.Second).Should(Equal(int64(99)),
+		}, 5*time.Second).WithPolling(500*time.Millisecond).Should(Equal(int64(99)),
 			"Offset should be 99")
 		Expect(consumer.Close()).NotTo(HaveOccurred())
-		/// When the consumer is closed, it has to save the offset
-		// so the last offset has to be 104
-		Eventually(func() int64 {
-			return consumer.GetLastStoredOffset()
-		}, 5*time.Second).Should(Equal(int64(104)),
-			"Offset should be 104")
 
 		consumerTimer, errTimer := env.NewConsumer(streamName,
 			func(_ ConsumerContext, _ *amqp.Message) {
@@ -259,19 +259,16 @@ var _ = Describe("Streaming Consumers", func() {
 					SetCountBeforeStorage(10000000). /// We avoid raising the timer
 					SetFlushInterval(1*time.Second)))
 		Expect(errTimer).NotTo(HaveOccurred())
-		time.Sleep(2 * time.Second)
-		Eventually(func() int64 {
-			return consumerTimer.GetLastStoredOffset()
-		}, 5*time.Second).Should(Equal(int64(104)),
+		Eventually(func() (int64, error) {
+			v, err := consumerTimer.QueryOffset()
+			// we can ignore the offset not found error here
+			if err != nil {
+				return 0, nil
+			}
+			return v, err
+		}, 5*time.Second).WithPolling(500*time.Millisecond).Should(Equal(int64(104)),
 			"Offset should be 104")
 		Expect(consumerTimer.Close()).NotTo(HaveOccurred())
-		/// When the consumer is closed, it has to save the offset
-		// so the last offest has to be 104
-		Eventually(func() int64 {
-			return consumerTimer.GetLastStoredOffset()
-		}, 5*time.Second).Should(Equal(int64(104)),
-			"Offset should be 104")
-
 	})
 })
 
@@ -285,6 +282,7 @@ var _ = Describe("Streaming Consumers", func() {
 			func(_ ConsumerContext, _ *amqp.Message) {
 				atomic.AddInt32(&messagesReceived, 1)
 			}, NewConsumerOptions().
+				SetConsumerName("autoCommitStrategy").
 				SetAutoCommit(NewAutoCommitStrategy().
 					SetCountBeforeStorage(10000000).
 					SetFlushInterval(time.Second)))
@@ -294,14 +292,17 @@ var _ = Describe("Streaming Consumers", func() {
 		for i := 0; i < maxMessages; i++ {
 			Expect(producer.Send(CreateMessageForTesting("", i))).NotTo(HaveOccurred())
 			// emit message before the flush interval has elapsed
-			time.Sleep(time.Millisecond * 100)
+			time.Sleep(time.Millisecond * 1100)
 
-			if consumer.GetLastStoredOffset() > 0 {
+			v, err := consumer.QueryOffset()
+			Expect(err).NotTo(HaveOccurred())
+			if v > 0 {
 				break
 			}
+
 		}
 
-		Expect(messagesReceived > 5 && messagesReceived < int32(maxMessages)).To(BeTrueBecause("%d messages received", messagesReceived))
+		Expect(messagesReceived > 0 && messagesReceived < int32(maxMessages)).To(BeTrueBecause("%d messages received", messagesReceived))
 		Expect(producer.Close()).NotTo(HaveOccurred())
 		Expect(consumer.Close()).NotTo(HaveOccurred())
 	})
@@ -404,8 +405,8 @@ var _ = Describe("Streaming Consumers", func() {
 		}, 5*time.Second).Should(Equal(int32(107)),
 			"consumer should receive same messages Send by producer")
 
-		Eventually(func() int64 {
-			return consumer.GetLastStoredOffset()
+		Eventually(func() (int64, error) {
+			return consumer.QueryOffset()
 			// 106 is the offset since it starts from 0
 		}, 5*time.Second).Should(Equal(int64(106)),
 			"Offset should be 106")
@@ -710,13 +711,25 @@ var _ = Describe("Streaming Consumers", func() {
 				NewAutoCommitStrategy().SetFlushInterval(10*time.Millisecond)))
 		Expect(err).To(HaveOccurred())
 
-		// message handler must be set
+		// must specify a valid offset
 		_, err = env.NewConsumer(streamName,
 			nil, &ConsumerOptions{
 				Offset: OffsetSpecification{},
 			})
 		Expect(err).To(HaveOccurred())
 
+		// handler is nil
+		_, err = env.NewConsumer(streamName,
+			nil, &ConsumerOptions{
+				Offset: OffsetSpecification{
+					typeOfs: typeFirst},
+			})
+		Expect(err).To(HaveOccurred())
+
+		_, err = env.NewConsumer(streamName,
+			nil, NewConsumerOptions().SetAutoCommit(NewAutoCommitStrategy()))
+		Expect(err).To(HaveOccurred())
+
 	})
 
 	It("Sub Batch consumer with different publishers GZIP and Not", func() {
diff --git a/pkg/stream/environment.go b/pkg/stream/environment.go
index ed39469f..366eb4b8 100644
--- a/pkg/stream/environment.go
+++ b/pkg/stream/environment.go
@@ -457,6 +457,10 @@ func (envOptions *EnvironmentOptions) SetRPCTimeout(timeout time.Duration) *Envi
 	return envOptions
 }
 
+type clientOptions interface {
+	GetClientProvidedName(defaultClientProvidedName string) string
+}
+
 type environmentCoordinator struct {
 	mutex             *sync.Mutex
 	clientsPerContext sync.Map
@@ -535,26 +539,34 @@ func (c *Client) maybeCleanConsumers(streamName string) {
 	})
 }
 
-func (cc *environmentCoordinator) newProducer(leader *Broker, tcpParameters *TCPParameters, saslConfiguration *SaslConfiguration, streamName string, options *ProducerOptions, rpcTimeout time.Duration, cleanUp func()) (*Producer, error) {
+func (cc *environmentCoordinator) newClientEntity(
+	isListFull func(int) bool,
+	defaultClientName string,
+	leader *Broker,
+	tcpParameters *TCPParameters,
+	saslConfiguration *SaslConfiguration,
+	options clientOptions,
+	rpcTimeout time.Duration,
+) (*Client, error) {
 	cc.mutex.Lock()
 	defer cc.mutex.Unlock()
 	var clientResult *Client
 
 	cc.clientsPerContext.Range(func(key, value any) bool {
-		if !cc.isProducerListFull(key.(int)) {
+		if !isListFull(key.(int)) {
 			clientResult = value.(*Client)
 			return false
 		}
 		return true
 	})
 
-	clientProvidedName := "go-stream-producer"
-	if options != nil && options.ClientProvidedName != "" {
-		clientProvidedName = options.ClientProvidedName
+	clientProvidedName := defaultClientName
+	if options != nil {
+		clientProvidedName = options.GetClientProvidedName(defaultClientName)
 	}
 
 	if clientResult == nil {
-		clientResult = cc.newClientForProducer(clientProvidedName, leader, tcpParameters, saslConfiguration, rpcTimeout)
+		clientResult = cc.newClientForConnection(clientProvidedName, leader, tcpParameters, saslConfiguration, rpcTimeout)
 	}
 
 	err := clientResult.connect()
@@ -562,68 +574,53 @@ func (cc *environmentCoordinator) newProducer(leader *Broker, tcpParameters *TCP
 		return nil, err
 	}
 
-	for clientResult.connectionProperties.host != leader.advHost ||
-		clientResult.connectionProperties.port != leader.advPort {
retry", - clientResult.connectionProperties.host, - leader.advHost, leader.advPort) - clientResult.Close() - clientResult = cc.newClientForProducer(clientProvidedName, leader, tcpParameters, saslConfiguration, rpcTimeout) - err = clientResult.connect() - if err != nil { - return nil, err - } - time.Sleep(1 * time.Second) - } - - producer, err := clientResult.declarePublisher(streamName, options, cleanUp) + return cc.validateBrokerConnection(clientResult, leader, + func() *Client { + return cc.newClientForConnection(clientProvidedName, leader, tcpParameters, saslConfiguration, rpcTimeout) + }) +} +func (cc *environmentCoordinator) newProducer(leader *Broker, tcpParameters *TCPParameters, saslConfiguration *SaslConfiguration, streamName string, options *ProducerOptions, rpcTimeout time.Duration, cleanUp func()) (*Producer, error) { + client, err := cc.newClientEntity(cc.isProducerListFull, "go-stream-producer", leader, tcpParameters, saslConfiguration, options, rpcTimeout) if err != nil { return nil, err } - - return producer, nil + return client.declarePublisher(streamName, options, cleanUp) } -func (cc *environmentCoordinator) newClientForProducer(connectionName string, leader *Broker, tcpParameters *TCPParameters, saslConfiguration *SaslConfiguration, rpcTimeOut time.Duration) *Client { - clientResult := newClient(connectionName, leader, tcpParameters, saslConfiguration, rpcTimeOut) - cc.nextId++ - - cc.clientsPerContext.Store(cc.nextId, clientResult) - return clientResult -} - -func (cc *environmentCoordinator) newConsumer(connectionName string, leader *Broker, tcpParameters *TCPParameters, saslConfiguration *SaslConfiguration, +func (cc *environmentCoordinator) newConsumer(leader *Broker, tcpParameters *TCPParameters, saslConfiguration *SaslConfiguration, streamName string, messagesHandler MessagesHandler, options *ConsumerOptions, rpcTimeout time.Duration, cleanUp func()) (*Consumer, error) { - cc.mutex.Lock() - defer cc.mutex.Unlock() - var clientResult *Client - - cc.clientsPerContext.Range(func(key, value any) bool { - if !cc.isConsumerListFull(key.(int)) { - clientResult = value.(*Client) - return false - } - return true - }) - - if clientResult == nil { - clientResult = newClient(connectionName, leader, tcpParameters, saslConfiguration, rpcTimeout) - cc.nextId++ - cc.clientsPerContext.Store(cc.nextId, clientResult) - } - // try to reconnect in case the socket is closed - err := clientResult.connect() + client, err := cc.newClientEntity(cc.isConsumerListFull, "go-stream-consumer", leader, tcpParameters, saslConfiguration, options, rpcTimeout) if err != nil { return nil, err } - subscriber, err := clientResult.declareSubscriber(streamName, messagesHandler, options, cleanUp) - if err != nil { - return nil, err + return client.declareSubscriber(streamName, messagesHandler, options, cleanUp) +} + +func (cc *environmentCoordinator) validateBrokerConnection(client *Client, broker *Broker, newClientFunc func() *Client) (*Client, error) { + for client.connectionProperties.host != broker.advHost || + client.connectionProperties.port != broker.advPort { + logs.LogDebug("connectionProperties host %s doesn't match with the advertised_host %s, advertised_port %s .. 
retry", + client.connectionProperties.host, + broker.advHost, broker.advPort) + client.Close() + client = newClientFunc() + err := client.connect() + if err != nil { + return nil, err + } + time.Sleep(time.Duration(500+rand.Intn(1000)) * time.Millisecond) } - return subscriber, nil + return client, nil +} + +func (cc *environmentCoordinator) newClientForConnection(connectionName string, broker *Broker, tcpParameters *TCPParameters, saslConfiguration *SaslConfiguration, rpcTimeout time.Duration) *Client { + clientResult := newClient(connectionName, broker, tcpParameters, saslConfiguration, rpcTimeout) + cc.nextId++ + cc.clientsPerContext.Store(cc.nextId, clientResult) + return clientResult } func (cc *environmentCoordinator) Close() error { @@ -664,10 +661,12 @@ func (ps *producersEnvironment) newProducer(clientLocator *Client, streamName st options *ProducerOptions, resolver *AddressResolver, rpcTimeOut time.Duration) (*Producer, error) { ps.mutex.Lock() defer ps.mutex.Unlock() - leader, err := clientLocator.BrokerLeader(streamName) + + leader, err := clientLocator.BrokerLeaderWithResolver(streamName, resolver) if err != nil { return nil, err } + coordinatorKey := leader.hostPort() if ps.producersCoordinator[coordinatorKey] == nil { ps.producersCoordinator[coordinatorKey] = &environmentCoordinator{ @@ -677,6 +676,7 @@ func (ps *producersEnvironment) newProducer(clientLocator *Client, streamName st nextId: 0, } } + leader.cloneFrom(clientLocator.broker, resolver) cleanUp := func() { @@ -684,6 +684,7 @@ func (ps *producersEnvironment) newProducer(clientLocator *Client, streamName st coordinator.maybeCleanClients() } } + producer, err := ps.producersCoordinator[coordinatorKey].newProducer(leader, clientLocator.tcpParameters, clientLocator.saslConfiguration, streamName, options, rpcTimeOut, cleanUp) if err != nil { @@ -728,10 +729,12 @@ func (ps *consumersEnvironment) NewSubscriber(clientLocator *Client, streamName consumerOptions *ConsumerOptions, resolver *AddressResolver, rpcTimeout time.Duration) (*Consumer, error) { ps.mutex.Lock() defer ps.mutex.Unlock() + consumerBroker, err := clientLocator.BrokerForConsumer(streamName) if err != nil { return nil, err } + coordinatorKey := consumerBroker.hostPort() if ps.consumersCoordinator[coordinatorKey] == nil { ps.consumersCoordinator[coordinatorKey] = &environmentCoordinator{ @@ -741,11 +744,9 @@ func (ps *consumersEnvironment) NewSubscriber(clientLocator *Client, streamName nextId: 0, } } + consumerBroker.cloneFrom(clientLocator.broker, resolver) - clientProvidedName := "go-stream-consumer" - if consumerOptions != nil && consumerOptions.ClientProvidedName != "" { - clientProvidedName = consumerOptions.ClientProvidedName - } + cleanUp := func() { for _, coordinator := range ps.consumersCoordinator { coordinator.maybeCleanClients() @@ -753,7 +754,7 @@ func (ps *consumersEnvironment) NewSubscriber(clientLocator *Client, streamName } consumer, err := ps.consumersCoordinator[coordinatorKey]. 
-		newConsumer(clientProvidedName, consumerBroker, clientLocator.tcpParameters,
+		newConsumer(consumerBroker, clientLocator.tcpParameters,
 			clientLocator.saslConfiguration, streamName, messagesHandler, consumerOptions, rpcTimeout, cleanUp)
 	if err != nil {
diff --git a/pkg/stream/producer.go b/pkg/stream/producer.go
index 17a020da..646d7eb1 100644
--- a/pkg/stream/producer.go
+++ b/pkg/stream/producer.go
@@ -189,6 +189,13 @@ func (po *ProducerOptions) SetClientProvidedName(name string) *ProducerOptions {
 	return po
 }
 
+func (po *ProducerOptions) GetClientProvidedName(defaultClientProvidedName string) string {
+	if po == nil || po.ClientProvidedName == "" {
+		return defaultClientProvidedName
+	}
+	return po.ClientProvidedName
+}
+
 // SetFilter sets the filter for the producer. See ProducerOptions.Filter for more details
 func (po *ProducerOptions) SetFilter(filter *ProducerFilter) *ProducerOptions {
 	po.Filter = filter
diff --git a/pkg/test-helper/http_utils.go b/pkg/test-helper/http_utils.go
index fd9422d6..7bfc3db2 100644
--- a/pkg/test-helper/http_utils.go
+++ b/pkg/test-helper/http_utils.go
@@ -6,6 +6,7 @@ import (
 	"io"
 	"net/http"
 	"strconv"
+	"time"
 
 	"github.com/pkg/errors"
 )
@@ -33,6 +34,20 @@ func Connections(port string) ([]connection, error) {
 	return data, nil
 }
 
+// IsConnectionAlive checks whether a connection with the given client provided name is alive
+func IsConnectionAlive(clientProvidedName string, port string) (bool, error) {
+	connections, err := Connections(port)
+	if err != nil {
+		return false, err
+	}
+	for _, connection := range connections {
+		if connection.ClientProperties.Connection_name == clientProvidedName {
+			return true, nil
+		}
+	}
+	return false, nil
+}
+
 func DropConnectionClientProvidedName(clientProvidedName string, port string) error {
 	connections, err := Connections(port)
 	if err != nil {
@@ -58,6 +73,30 @@ func DropConnectionClientProvidedName(clientProvidedName string, port string) er
 	return nil
 }
 
+// DropConnectionAndWait drops the connection with the given client provided name
+// and waits until it is no longer reported as alive, or the timeout expires.
+func DropConnectionAndWait(clientProvidedName string, port string, timeout time.Duration) error {
+	err := DropConnectionClientProvidedName(clientProvidedName, port)
+	if err != nil {
+		return err
+	}
+
+	// poll until the connection is gone or the timeout expires
+	deadline := time.Now().Add(timeout)
+	for time.Now().Before(deadline) {
+		isAlive, err := IsConnectionAlive(clientProvidedName, port)
+		if err != nil {
+			return err
+		}
+		if !isAlive {
+			return nil
+		}
+		time.Sleep(500 * time.Millisecond)
+	}
+
+	return errors.Errorf("timeout: connection %s is still alive after %s", clientProvidedName, timeout)
+}
+
 func DropConnection(name string, port string) error {
 	_, err := httpDelete("http://localhost:"+port+"/api/connections/"+name, "guest", "guest")
 	if err != nil {