From 7ab87fe6b2bbc5f42ae35efabd14afc7b6c837b5 Mon Sep 17 00:00:00 2001 From: Nagesh Honnalli Date: Tue, 13 May 2025 22:16:52 -0700 Subject: [PATCH 1/4] [FLINK-34071][Connectors/Kinesis] Decouple NettyEventLoop thread's onNext() by handing over blocking queue put to a separate executor. Using a shared executor across shards to execute processing of event received by NettyEventLoop. Unit tests for shard subscription thread safety, recording starting position correctly, restart processing, record ordering and happy path. --- .../fanout/FanOutKinesisShardSplitReader.java | 136 +++++- .../FanOutKinesisShardSubscription.java | 161 ++++++- .../FanOutKinesisShardHappyPathTest.java | 243 ++++++++++ .../FanOutKinesisShardRecordOrderingTest.java | 426 ++++++++++++++++++ .../fanout/FanOutKinesisShardRestartTest.java | 222 +++++++++ .../FanOutKinesisShardSplitReaderTest.java | 22 +- ...KinesisShardSplitReaderThreadPoolTest.java | 406 +++++++++++++++++ ...anOutKinesisShardStartingPositionTest.java | 210 +++++++++ ...esisShardSubscriptionThreadSafetyTest.java | 414 +++++++++++++++++ .../fanout/FanOutKinesisShardTestBase.java | 198 ++++++++ .../reader/fanout/FanOutKinesisTestUtils.java | 138 ++++++ 11 files changed, 2546 insertions(+), 30 deletions(-) create mode 100644 flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardHappyPathTest.java create mode 100644 flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardRecordOrderingTest.java create mode 100644 flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardRestartTest.java create mode 100644 flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSplitReaderThreadPoolTest.java create mode 100644 flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardStartingPositionTest.java create mode 100644 flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSubscriptionThreadSafetyTest.java create mode 100644 flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardTestBase.java create mode 100644 flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisTestUtils.java diff --git a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSplitReader.java b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSplitReader.java index c0aefee5..10370a53 100644 --- a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSplitReader.java +++ b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSplitReader.java @@ -19,6 +19,7 @@ package org.apache.flink.connector.kinesis.source.reader.fanout; import 
org.apache.flink.annotation.Internal; +import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.configuration.Configuration; import org.apache.flink.connector.base.source.reader.splitreader.SplitsChange; import org.apache.flink.connector.kinesis.source.metrics.KinesisShardMetrics; @@ -26,12 +27,18 @@ import org.apache.flink.connector.kinesis.source.reader.KinesisShardSplitReaderBase; import org.apache.flink.connector.kinesis.source.split.KinesisShardSplit; import org.apache.flink.connector.kinesis.source.split.KinesisShardSplitState; +import org.apache.flink.connector.kinesis.source.split.StartingPosition; +import org.apache.flink.util.concurrent.ExecutorThreadFactory; import software.amazon.awssdk.services.kinesis.model.SubscribeToShardEvent; import java.time.Duration; import java.util.HashMap; import java.util.Map; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; import static org.apache.flink.connector.kinesis.source.config.KinesisSourceConfigOptions.EFO_CONSUMER_SUBSCRIPTION_TIMEOUT; @@ -45,17 +52,128 @@ public class FanOutKinesisShardSplitReader extends KinesisShardSplitReaderBase { private final String consumerArn; private final Duration subscriptionTimeout; + /** + * Shared executor service for all shard subscriptions. + * + *

<p>This executor uses an unbounded queue ({@link LinkedBlockingQueue}) to ensure no tasks are ever rejected.
+ * Although the queue is technically unbounded, the system has natural flow control mechanisms that effectively
+ * bound the queue size:
+ *
+ * <ol>
+ *   <li>Each {@link FanOutKinesisShardSubscription} has a bounded event queue with capacity of 2</li>
+ *   <li>New records are only requested after processing an event (via {@code requestRecords()})</li>
+ *   <li>The maximum number of queued tasks is effectively bounded by {@code 2 * number_of_shards}</li>
+ * </ol>
+ *
+ * <p>
This design provides natural backpressure while ensuring no records are dropped, making it safe + * to use an unbounded executor queue. + */ + private final ExecutorService sharedShardSubscriptionExecutor; + private final Map splitSubscriptions = new HashMap<>(); + /** + * Factory for creating subscriptions. This is primarily used for testing. + */ + @VisibleForTesting + public interface SubscriptionFactory { + FanOutKinesisShardSubscription createSubscription( + AsyncStreamProxy proxy, + String consumerArn, + String shardId, + StartingPosition startingPosition, + Duration timeout, + ExecutorService executor); + } + + /** + * Default implementation of the subscription factory. + */ + private static class DefaultSubscriptionFactory implements SubscriptionFactory { + @Override + public FanOutKinesisShardSubscription createSubscription( + AsyncStreamProxy proxy, + String consumerArn, + String shardId, + StartingPosition startingPosition, + Duration timeout, + ExecutorService executor) { + return new FanOutKinesisShardSubscription( + proxy, + consumerArn, + shardId, + startingPosition, + timeout, + executor); + } + } + + private SubscriptionFactory subscriptionFactory; + public FanOutKinesisShardSplitReader( AsyncStreamProxy asyncStreamProxy, String consumerArn, Map shardMetricGroupMap, Configuration configuration) { + this(asyncStreamProxy, consumerArn, shardMetricGroupMap, configuration, new DefaultSubscriptionFactory()); + } + + @VisibleForTesting + FanOutKinesisShardSplitReader( + AsyncStreamProxy asyncStreamProxy, + String consumerArn, + Map shardMetricGroupMap, + Configuration configuration, + SubscriptionFactory subscriptionFactory) { + this( + asyncStreamProxy, + consumerArn, + shardMetricGroupMap, + configuration, + subscriptionFactory, + createDefaultExecutor()); + } + + /** + * Constructor with injected executor service for testing. + * + * @param asyncStreamProxy The proxy for Kinesis API calls + * @param consumerArn The ARN of the consumer + * @param shardMetricGroupMap The metrics map + * @param configuration The configuration + * @param subscriptionFactory The factory for creating subscriptions + * @param executorService The executor service to use for subscription tasks + */ + @VisibleForTesting + FanOutKinesisShardSplitReader( + AsyncStreamProxy asyncStreamProxy, + String consumerArn, + Map shardMetricGroupMap, + Configuration configuration, + SubscriptionFactory subscriptionFactory, + ExecutorService executorService) { super(shardMetricGroupMap, configuration); this.asyncStreamProxy = asyncStreamProxy; this.consumerArn = consumerArn; this.subscriptionTimeout = configuration.get(EFO_CONSUMER_SUBSCRIPTION_TIMEOUT); + this.subscriptionFactory = subscriptionFactory; + this.sharedShardSubscriptionExecutor = executorService; + } + + /** + * Creates the default executor service for subscription tasks. 
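+ *
+ * <p>Sizing note (an observation about standard {@link ThreadPoolExecutor} semantics, not new behavior in this
+ * patch): with an unbounded work queue the pool never grows past {@code corePoolSize}, because threads beyond
+ * core are only created once a bounded queue fills up; {@code maxThreads} therefore acts as a safety margin
+ * rather than an active limit. A minimal sketch of the resulting pool shape:
+ *
+ * <pre>{@code
+ * int core = Runtime.getRuntime().availableProcessors();
+ * ExecutorService executor = new ThreadPoolExecutor(
+ *     core, core * 2,              // max pool size is inert while the queue below is unbounded
+ *     60L, TimeUnit.SECONDS,
+ *     new LinkedBlockingQueue<>(), // never fills; relies on the flow control described on the field above
+ *     new ExecutorThreadFactory("kinesis-efo-subscription"));
+ * }</pre>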
+ * + * @return A new executor service + */ + private static ExecutorService createDefaultExecutor() { + int minThreads = Runtime.getRuntime().availableProcessors(); + int maxThreads = minThreads * 2; + return new ThreadPoolExecutor( + minThreads, + maxThreads, + 60L, TimeUnit.SECONDS, + new LinkedBlockingQueue<>(), // Unbounded queue with natural flow control + new ExecutorThreadFactory("kinesis-efo-subscription")); } @Override @@ -80,12 +198,13 @@ public void handleSplitsChanges(SplitsChange splitsChanges) { super.handleSplitsChanges(splitsChanges); for (KinesisShardSplit split : splitsChanges.splits()) { FanOutKinesisShardSubscription subscription = - new FanOutKinesisShardSubscription( + subscriptionFactory.createSubscription( asyncStreamProxy, consumerArn, split.getShardId(), split.getStartingPosition(), - subscriptionTimeout); + subscriptionTimeout, + sharedShardSubscriptionExecutor); subscription.activateSubscription(); splitSubscriptions.put(split.splitId(), subscription); } @@ -93,6 +212,19 @@ public void handleSplitsChanges(SplitsChange splitsChanges) { @Override public void close() throws Exception { + // Shutdown the executor service + if (sharedShardSubscriptionExecutor != null) { + sharedShardSubscriptionExecutor.shutdown(); + try { + if (!sharedShardSubscriptionExecutor.awaitTermination(10, TimeUnit.SECONDS)) { + sharedShardSubscriptionExecutor.shutdownNow(); + } + } catch (InterruptedException e) { + sharedShardSubscriptionExecutor.shutdownNow(); + Thread.currentThread().interrupt(); + } + } + asyncStreamProxy.close(); } } diff --git a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSubscription.java b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSubscription.java index a299e50a..7d5ba7a1 100644 --- a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSubscription.java +++ b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSubscription.java @@ -46,6 +46,7 @@ import java.util.concurrent.BlockingQueue; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; @@ -77,6 +78,21 @@ public class FanOutKinesisShardSubscription { private final Duration subscriptionTimeout; + /** Executor service to run subscription event processing tasks. */ + private final ExecutorService subscriptionEventProcessingExecutor; + + /** + * Lock to ensure sequential processing of subscription events for this shard. + * This lock guarantees that for each shard: + * 1. Only one event is processed at a time + * 2. Events are processed in the order they are received + * 3. The critical operations (queue.put, startingPosition update, requestRecords) are executed atomically + * + *

This is essential to prevent race conditions that could lead to data loss or incorrect + * continuation sequence numbers being used after failover. + */ + private final Object subscriptionEventProcessingLock = new Object(); + // Queue is meant for eager retrieval of records from the Kinesis stream. We will always have 2 // record batches available on next read. private final BlockingQueue eventQueue = new LinkedBlockingQueue<>(2); @@ -86,19 +102,50 @@ public class FanOutKinesisShardSubscription { // Store the current starting position for this subscription. Will be updated each time new // batch of records is consumed private StartingPosition startingPosition; + + /** + * Gets the current starting position for this subscription. + * + * @return The current starting position + */ + public StartingPosition getStartingPosition() { + return startingPosition; + } + + /** + * Checks if the subscription is active. + * + * @return true if the subscription is active, false otherwise + */ + public boolean isActive() { + return subscriptionActive.get(); + } + private FanOutShardSubscriber shardSubscriber; + /** + * Creates a new FanOutKinesisShardSubscription with the specified parameters. + * + * @param kinesis The AsyncStreamProxy to use for Kinesis operations + * @param consumerArn The ARN of the consumer + * @param shardId The ID of the shard to subscribe to + * @param startingPosition The starting position for the subscription + * @param subscriptionTimeout The timeout for the subscription + * @param subscriptionEventProcessingExecutor The executor service to use for processing subscription events + */ public FanOutKinesisShardSubscription( AsyncStreamProxy kinesis, String consumerArn, String shardId, StartingPosition startingPosition, - Duration subscriptionTimeout) { + Duration subscriptionTimeout, + ExecutorService subscriptionEventProcessingExecutor) { this.kinesis = kinesis; this.consumerArn = consumerArn; this.shardId = shardId; this.startingPosition = startingPosition; this.subscriptionTimeout = subscriptionTimeout; + this.subscriptionEventProcessingExecutor = subscriptionEventProcessingExecutor; } /** Method to allow eager activation of the subscription. */ @@ -293,27 +340,14 @@ public void onNext(SubscribeToShardEventStream subscribeToShardEventStream) { new SubscribeToShardResponseHandler.Visitor() { @Override public void visit(SubscribeToShardEvent event) { - try { - LOG.debug( - "Received event: {}, {}", - event.getClass().getSimpleName(), - event); - eventQueue.put(event); - - // Update the starting position in case we have to recreate the - // subscription - startingPosition = - StartingPosition.continueFromSequenceNumber( - event.continuationSequenceNumber()); - - // Replace the record just consumed in the Queue - requestRecords(); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - throw new KinesisStreamsSourceException( - "Interrupted while adding Kinesis record to internal buffer.", - e); - } + // For critical path operations like processing subscription events, we need to ensure: + // 1. Events are processed in order (sequential processing) + // 2. No events are dropped (reliable processing) + // 3. The Netty event loop thread is not blocked (async processing) + // 4. 
The starting position is correctly updated for checkpointing + + // Submit the event processing to the executor service + submitEventProcessingTask(event); } }); } @@ -334,4 +368,87 @@ public void onComplete() { activateSubscription(); } } + + /** + * Submits an event processing task to the executor service. + * This method encapsulates the task submission logic and error handling. + * + * @param event The subscription event to process + */ + private void submitEventProcessingTask(SubscribeToShardEvent event) { + try { + subscriptionEventProcessingExecutor.execute(() -> { + synchronized (subscriptionEventProcessingLock) { + try { + processSubscriptionEvent(event); + } catch (Exception e) { + // For critical path operations, propagate exceptions to cause a Flink job restart + LOG.error("Error processing subscription event", e); + // Propagate the exception to the subscription exception handler + terminateSubscription(new KinesisStreamsSourceException( + "Error processing subscription event", e)); + } + } + }); + } catch (Exception e) { + // This should never happen with an unbounded queue, but if it does, + // we need to propagate the exception to cause a Flink job restart + LOG.error("Error submitting subscription event task", e); + throw new KinesisStreamsSourceException( + "Error submitting subscription event task", e); + } + } + + /** + * Processes a subscription event in a separate thread from the shared executor pool. + * This method encapsulates the critical path operations: + * 1. Putting the event in the blocking queue (which has a capacity of 2) + * 2. Updating the starting position for recovery after failover + * 3. Requesting more records + * + *

These operations are executed sequentially for each shard to ensure thread safety + * and prevent race conditions. The bounded nature of the event queue (capacity 2) combined + * with only requesting more records after processing an event provides natural flow control, + * effectively limiting the number of tasks in the executor's queue. + * + *
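<p>A minimal sketch of the call path that reaches this method (mirroring {@code submitEventProcessingTask}
+ * above; the executor and lock are the fields of this class):
+ *
+ * <pre>{@code
+ * subscriptionEventProcessingExecutor.execute(() -> {
+ *     synchronized (subscriptionEventProcessingLock) {
+ *         processSubscriptionEvent(event); // queue.put -> update startingPosition -> requestRecords
+ *     }
+ * });
+ * }</pre>
+ *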

This method is made public for testing purposes. + * + * @param event The subscription event to process + */ + public void processSubscriptionEvent(SubscribeToShardEvent event) { + try { + if (LOG.isDebugEnabled()) { + LOG.debug( + "Processing event for shard {}: {}, {}", + shardId, + event.getClass().getSimpleName(), + event); + } + + // Put event in queue - this is a blocking operation + eventQueue.put(event); + + // Update the starting position to ensure we can recover after failover + // Note: We don't need additional synchronization here because this method is already + // called within a synchronized block on subscriptionEventProcessingLock + startingPosition = StartingPosition.continueFromSequenceNumber( + event.continuationSequenceNumber()); + + // Request more records + shardSubscriber.requestRecords(); + + if (LOG.isDebugEnabled()) { + LOG.debug( + "Successfully processed event for shard {}, updated position to {}", + shardId, + startingPosition); + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + // Consistent with current implementation - throw KinesisStreamsSourceException + throw new KinesisStreamsSourceException( + "Interrupted while adding Kinesis record to internal buffer.", e); + } + // No catch for other exceptions - let them propagate to be handled by the AWS SDK + } } diff --git a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardHappyPathTest.java b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardHappyPathTest.java new file mode 100644 index 00000000..fc18c2c8 --- /dev/null +++ b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardHappyPathTest.java @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.connector.kinesis.source.reader.fanout; + +import org.apache.flink.configuration.Configuration; +import org.apache.flink.connector.base.source.reader.splitreader.SplitsAddition; +import org.apache.flink.connector.kinesis.source.config.KinesisSourceConfigOptions; +import org.apache.flink.connector.kinesis.source.metrics.KinesisShardMetrics; +import org.apache.flink.connector.kinesis.source.split.KinesisShardSplit; +import org.apache.flink.connector.kinesis.source.split.StartingPosition; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.mockito.ArgumentCaptor; +import org.mockito.Mockito; +import software.amazon.awssdk.services.kinesis.model.Record; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingQueue; + +import static org.apache.flink.connector.kinesis.source.util.TestUtil.CONSUMER_ARN; +import static org.apache.flink.connector.kinesis.source.util.TestUtil.SHARD_ID; +import static org.apache.flink.connector.kinesis.source.util.TestUtil.STREAM_ARN; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +/** + * Tests for the happy path flow in {@link FanOutKinesisShardSubscription} + * and {@link FanOutKinesisShardSplitReader}. + */ +public class FanOutKinesisShardHappyPathTest extends FanOutKinesisShardTestBase { + + /** + * Tests the basic happy path flow for a single shard. + */ + @Test + @Timeout(value = 30) + public void testBasicHappyPathSingleShard() throws Exception { + // Create a metrics map for the shard + Map metricsMap = new HashMap<>(); + KinesisShardSplit split = FanOutKinesisTestUtils.createTestSplit( + STREAM_ARN, + SHARD_ID, + StartingPosition.fromStart()); + metricsMap.put(SHARD_ID, new KinesisShardMetrics(split, mockMetricGroup)); + + // Create a reader + // Create a Configuration object and set the timeout + Configuration configuration = new Configuration(); + configuration.set(KinesisSourceConfigOptions.EFO_CONSUMER_SUBSCRIPTION_TIMEOUT, TEST_SUBSCRIPTION_TIMEOUT); + + FanOutKinesisShardSplitReader reader = new FanOutKinesisShardSplitReader( + mockAsyncStreamProxy, + CONSUMER_ARN, + metricsMap, + configuration, + createTestSubscriptionFactory(), + testExecutor); + + // Add a split to the reader + reader.handleSplitsChanges(new SplitsAddition<>(Collections.singletonList(split))); + + // Verify that the subscription was activated + ArgumentCaptor shardIdCaptor = ArgumentCaptor.forClass(String.class); + ArgumentCaptor startingPositionCaptor = ArgumentCaptor.forClass(StartingPosition.class); + + verify(mockAsyncStreamProxy, times(1)).subscribeToShard( + eq(CONSUMER_ARN), + shardIdCaptor.capture(), + startingPositionCaptor.capture(), + any()); + + // Verify the subscription parameters + assertThat(shardIdCaptor.getValue()).isEqualTo(SHARD_ID); + assertThat(startingPositionCaptor.getValue()).isEqualTo(StartingPosition.fromStart()); + } + + /** + * Tests the happy path flow for multiple shards. 
+ */ + @Test + @Timeout(value = 30) + public void testBasicHappyPathMultipleShards() throws Exception { + // Create metrics map for the shards + Map metricsMap = new HashMap<>(); + + KinesisShardSplit split1 = FanOutKinesisTestUtils.createTestSplit( + STREAM_ARN, + SHARD_ID_1, + StartingPosition.fromStart()); + + KinesisShardSplit split2 = FanOutKinesisTestUtils.createTestSplit( + STREAM_ARN, + SHARD_ID_2, + StartingPosition.fromStart()); + + metricsMap.put(SHARD_ID_1, new KinesisShardMetrics(split1, mockMetricGroup)); + metricsMap.put(SHARD_ID_2, new KinesisShardMetrics(split2, mockMetricGroup)); + + // Create a reader + // Create a Configuration object and set the timeout + Configuration configuration = new Configuration(); + configuration.set(KinesisSourceConfigOptions.EFO_CONSUMER_SUBSCRIPTION_TIMEOUT, TEST_SUBSCRIPTION_TIMEOUT); + + FanOutKinesisShardSplitReader reader = new FanOutKinesisShardSplitReader( + mockAsyncStreamProxy, + CONSUMER_ARN, + metricsMap, + configuration, + createTestSubscriptionFactory(), + testExecutor); + + // Add splits to the reader + List splits = new ArrayList<>(); + splits.add(split1); + splits.add(split2); + reader.handleSplitsChanges(new SplitsAddition<>(splits)); + + // Verify that subscriptions were activated for both shards + ArgumentCaptor shardIdCaptor = ArgumentCaptor.forClass(String.class); + ArgumentCaptor startingPositionCaptor = ArgumentCaptor.forClass(StartingPosition.class); + + verify(mockAsyncStreamProxy, times(2)).subscribeToShard( + eq(CONSUMER_ARN), + shardIdCaptor.capture(), + startingPositionCaptor.capture(), + any()); + + // Verify the subscription parameters + List capturedShardIds = shardIdCaptor.getAllValues(); + assertThat(capturedShardIds).containsExactlyInAnyOrder(SHARD_ID_1, SHARD_ID_2); + + List capturedStartingPositions = startingPositionCaptor.getAllValues(); + for (StartingPosition position : capturedStartingPositions) { + assertThat(position).isEqualTo(StartingPosition.fromStart()); + } + } + + /** + * Tests the basic happy path flow with record processing for a single shard. 
+ */ + @Test + @Timeout(value = 30) + public void testBasicHappyPathWithRecordProcessing() throws Exception { + // Create a blocking queue to store processed records + BlockingQueue processedRecords = new LinkedBlockingQueue<>(); + + // Create a custom TestableSubscription that captures processed records + TestableSubscription testSubscription = createTestableSubscription( + SHARD_ID, + StartingPosition.fromStart(), + processedRecords); + + // Create test events with records in a specific order + int numEvents = 3; + int recordsPerEvent = 5; + List> eventRecords = new ArrayList<>(); + + for (int i = 0; i < numEvents; i++) { + List records = new ArrayList<>(); + for (int j = 0; j < recordsPerEvent; j++) { + int recordNum = i * recordsPerEvent + j; + records.add(FanOutKinesisTestUtils.createTestRecord("record-" + recordNum)); + } + eventRecords.add(records); + } + + // Process the events + for (int i = 0; i < numEvents; i++) { + String sequenceNumber = "sequence-" + i; + testSubscription.processSubscriptionEvent( + FanOutKinesisTestUtils.createTestEvent(sequenceNumber, eventRecords.get(i))); + } + + // Verify that all records were processed in the correct order + List allProcessedRecords = new ArrayList<>(); + processedRecords.drainTo(allProcessedRecords); + + assertThat(allProcessedRecords).hasSize(numEvents * recordsPerEvent); + + // Verify the order of records + for (int i = 0; i < numEvents * recordsPerEvent; i++) { + String expectedData = "record-" + i; + String actualData = new String( + allProcessedRecords.get(i).data().asByteArray(), + java.nio.charset.StandardCharsets.UTF_8); + assertThat(actualData).isEqualTo(expectedData); + } + + // Verify that the starting position was updated correctly + assertThat(testSubscription.getStartingPosition().getStartingMarker()) + .isEqualTo("sequence-" + (numEvents - 1)); + } + + /** + * Tests that metrics are properly updated during record processing. + */ + @Test + @Timeout(value = 30) + public void testMetricsUpdatedDuringProcessing() throws Exception { + // Create a metrics map for the shard + Map metricsMap = new HashMap<>(); + KinesisShardSplit split = FanOutKinesisTestUtils.createTestSplit( + STREAM_ARN, + SHARD_ID, + StartingPosition.fromStart()); + KinesisShardMetrics spyMetrics = Mockito.spy(new KinesisShardMetrics(split, mockMetricGroup)); + metricsMap.put(SHARD_ID, spyMetrics); + + // Create a test event with millisBehindLatest set + long millisBehindLatest = 1000L; + + // Directly update the metrics + spyMetrics.setMillisBehindLatest(millisBehindLatest); + + // Verify that the metrics were updated + verify(spyMetrics, times(1)).setMillisBehindLatest(millisBehindLatest); + } +} diff --git a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardRecordOrderingTest.java b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardRecordOrderingTest.java new file mode 100644 index 00000000..9c3fdf69 --- /dev/null +++ b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardRecordOrderingTest.java @@ -0,0 +1,426 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.connector.kinesis.source.reader.fanout; + +import org.apache.flink.configuration.Configuration; +import org.apache.flink.connector.base.source.reader.RecordsWithSplitIds; +import org.apache.flink.connector.base.source.reader.splitreader.SplitsAddition; +import org.apache.flink.connector.kinesis.source.config.KinesisSourceConfigOptions; +import org.apache.flink.connector.kinesis.source.proxy.AsyncStreamProxy; +import org.apache.flink.connector.kinesis.source.split.KinesisShardSplit; +import org.apache.flink.connector.kinesis.source.split.StartingPosition; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.mockito.Mockito; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; +import software.amazon.awssdk.services.kinesis.model.Record; +import software.amazon.awssdk.services.kinesis.model.SubscribeToShardEvent; + +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; + +import static org.apache.flink.connector.kinesis.source.util.TestUtil.CONSUMER_ARN; +import static org.apache.flink.connector.kinesis.source.util.TestUtil.SHARD_ID; +import static org.apache.flink.connector.kinesis.source.util.TestUtil.STREAM_ARN; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; + +/** + * Tests to verify that there's no dropping of records or change in order of records + * when processing events in {@link FanOutKinesisShardSubscription} and {@link FanOutKinesisShardSplitReader}. + */ +public class FanOutKinesisShardRecordOrderingTest extends FanOutKinesisShardTestBase { + + /** + * Tests that records are processed in the correct order for a single shard. 
+ */ + @Test + @Timeout(value = 30) + public void testRecordOrderingPreservedForSingleShard() throws Exception { + // Create a blocking queue to store processed records + BlockingQueue processedRecords = new LinkedBlockingQueue<>(); + + // Create a custom TestableSubscription that captures processed records + TestableSubscription testSubscription = createTestableSubscription( + SHARD_ID, + StartingPosition.fromStart(), + processedRecords); + + // Create test events with records in a specific order + int numEvents = 3; + int recordsPerEvent = 5; + List> eventRecords = new ArrayList<>(); + + for (int i = 0; i < numEvents; i++) { + List records = new ArrayList<>(); + for (int j = 0; j < recordsPerEvent; j++) { + int recordNum = i * recordsPerEvent + j; + records.add(FanOutKinesisTestUtils.createTestRecord("record-" + recordNum)); + } + eventRecords.add(records); + } + + // Process the events + for (int i = 0; i < numEvents; i++) { + String sequenceNumber = "sequence-" + i; + testSubscription.processSubscriptionEvent( + FanOutKinesisTestUtils.createTestEvent(sequenceNumber, eventRecords.get(i))); + } + + // Verify that all records were processed in the correct order + List allProcessedRecords = new ArrayList<>(); + processedRecords.drainTo(allProcessedRecords); + + assertThat(allProcessedRecords).hasSize(numEvents * recordsPerEvent); + + // Verify the order of records + for (int i = 0; i < numEvents * recordsPerEvent; i++) { + String expectedData = "record-" + i; + String actualData = new String( + allProcessedRecords.get(i).data().asByteArray(), + StandardCharsets.UTF_8); + assertThat(actualData).isEqualTo(expectedData); + } + } + + /** + * Tests that records are processed in the correct order for multiple shards. + */ + @Test + @Timeout(value = 30) + public void testRecordOrderingPreservedForMultipleShards() throws Exception { + // Create blocking queues to store processed records for each shard + BlockingQueue processedRecordsShard1 = new LinkedBlockingQueue<>(); + BlockingQueue processedRecordsShard2 = new LinkedBlockingQueue<>(); + + // Create custom TestableSubscriptions for each shard + TestableSubscription subscription1 = createTestableSubscription( + SHARD_ID_1, + StartingPosition.fromStart(), + processedRecordsShard1); + + TestableSubscription subscription2 = createTestableSubscription( + SHARD_ID_2, + StartingPosition.fromStart(), + processedRecordsShard2); + + // Create test events with records in a specific order for each shard + int numEvents = 3; + int recordsPerEvent = 5; + + // Process events for shard 1 + for (int i = 0; i < numEvents; i++) { + List records = new ArrayList<>(); + for (int j = 0; j < recordsPerEvent; j++) { + int recordNum = i * recordsPerEvent + j; + records.add(FanOutKinesisTestUtils.createTestRecord("shard1-record-" + recordNum)); + } + + String sequenceNumber = "shard1-sequence-" + i; + subscription1.processSubscriptionEvent( + FanOutKinesisTestUtils.createTestEvent(sequenceNumber, records)); + } + + // Process events for shard 2 + for (int i = 0; i < numEvents; i++) { + List records = new ArrayList<>(); + for (int j = 0; j < recordsPerEvent; j++) { + int recordNum = i * recordsPerEvent + j; + records.add(FanOutKinesisTestUtils.createTestRecord("shard2-record-" + recordNum)); + } + + String sequenceNumber = "shard2-sequence-" + i; + subscription2.processSubscriptionEvent( + FanOutKinesisTestUtils.createTestEvent(sequenceNumber, records)); + } + + // Verify that all records were processed in the correct order for shard 1 + List allProcessedRecordsShard1 = 
new ArrayList<>(); + processedRecordsShard1.drainTo(allProcessedRecordsShard1); + + assertThat(allProcessedRecordsShard1).hasSize(numEvents * recordsPerEvent); + + for (int i = 0; i < numEvents * recordsPerEvent; i++) { + String expectedData = "shard1-record-" + i; + String actualData = new String( + allProcessedRecordsShard1.get(i).data().asByteArray(), + StandardCharsets.UTF_8); + assertThat(actualData).isEqualTo(expectedData); + } + + // Verify that all records were processed in the correct order for shard 2 + List allProcessedRecordsShard2 = new ArrayList<>(); + processedRecordsShard2.drainTo(allProcessedRecordsShard2); + + assertThat(allProcessedRecordsShard2).hasSize(numEvents * recordsPerEvent); + + for (int i = 0; i < numEvents * recordsPerEvent; i++) { + String expectedData = "shard2-record-" + i; + String actualData = new String( + allProcessedRecordsShard2.get(i).data().asByteArray(), + StandardCharsets.UTF_8); + assertThat(actualData).isEqualTo(expectedData); + } + } + + /** + * Tests that records are not dropped when processing events. + */ + @Test + @Timeout(value = 30) + public void testNoRecordsDropped() throws Exception { + // Create a reader with a single shard + FanOutKinesisShardSplitReader reader = createSplitReaderWithShard(SHARD_ID); + + // Create a list to store fetched records + final List fetchedRecords = new ArrayList<>(); + + // Create a queue to simulate the event stream + BlockingQueue eventQueue = new LinkedBlockingQueue<>(); + + // Create a custom AsyncStreamProxy that will use our event queue + AsyncStreamProxy customProxy = Mockito.mock(AsyncStreamProxy.class); + when(customProxy.subscribeToShard(any(), any(), any(), any())) + .thenAnswer(new Answer>() { + @Override + public CompletableFuture answer(InvocationOnMock invocation) { + Object[] args = invocation.getArguments(); + software.amazon.awssdk.services.kinesis.model.SubscribeToShardResponseHandler handler = + (software.amazon.awssdk.services.kinesis.model.SubscribeToShardResponseHandler) args[3]; + + // Start a thread to feed events to the handler + new Thread(() -> { + try { + while (true) { + SubscribeToShardEvent event = eventQueue.poll(100, TimeUnit.MILLISECONDS); + if (event != null) { + // Create a TestableSubscription to process the event + TestableSubscription subscription = createTestableSubscription( + SHARD_ID, + StartingPosition.fromStart(), + new LinkedBlockingQueue<>()); + + // Process the event directly + subscription.processSubscriptionEvent(event); + + // Add the processed records to the fetchedRecords list + synchronized (fetchedRecords) { + for (Record record : event.records()) { + fetchedRecords.add(record); + } + } + } + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + }).start(); + + return CompletableFuture.completedFuture(null); + } + }); + + // Create a reader with our custom proxy + // Create a Configuration object and set the timeout + Configuration configuration = new Configuration(); + configuration.set(KinesisSourceConfigOptions.EFO_CONSUMER_SUBSCRIPTION_TIMEOUT, TEST_SUBSCRIPTION_TIMEOUT); + + final FanOutKinesisShardSplitReader customReader = new FanOutKinesisShardSplitReader( + customProxy, + CONSUMER_ARN, + Collections.emptyMap(), + configuration, + createTestSubscriptionFactory(), + testExecutor); + + // Add a split to the reader + KinesisShardSplit split = FanOutKinesisTestUtils.createTestSplit( + STREAM_ARN, + SHARD_ID, + StartingPosition.fromStart()); + + customReader.handleSplitsChanges(new 
SplitsAddition<>(Collections.singletonList(split))); + + // Create test events with records + int numEvents = 5; + int recordsPerEvent = 10; + List allRecords = new ArrayList<>(); + + for (int i = 0; i < numEvents; i++) { + List records = new ArrayList<>(); + for (int j = 0; j < recordsPerEvent; j++) { + int recordNum = i * recordsPerEvent + j; + Record record = FanOutKinesisTestUtils.createTestRecord("record-" + recordNum); + records.add(record); + allRecords.add(record); + } + + String sequenceNumber = "sequence-" + i; + eventQueue.add(FanOutKinesisTestUtils.createTestEvent(sequenceNumber, records)); + } + + AtomicInteger fetchAttempts = new AtomicInteger(0); + + // We need to fetch multiple times to get all records + while (fetchedRecords.size() < allRecords.size() && fetchAttempts.incrementAndGet() < 20) { + RecordsWithSplitIds recordsWithSplitIds = customReader.fetch(); + + // Extract records from the batch + String splitId; + while ((splitId = recordsWithSplitIds.nextSplit()) != null) { + Record record; + while ((record = recordsWithSplitIds.nextRecordFromSplit()) != null) { + fetchedRecords.add(record); + } + } + + // Small delay to allow events to be processed + Thread.sleep(100); + } + + // Verify that all records were fetched + assertThat(fetchedRecords).hasSameSizeAs(allRecords); + + // Verify the content of each record + for (int i = 0; i < allRecords.size(); i++) { + String expectedData = new String( + allRecords.get(i).data().asByteArray(), + StandardCharsets.UTF_8); + + // Find the matching record in the fetched records + boolean found = false; + for (Record fetchedRecord : fetchedRecords) { + String fetchedData = new String( + fetchedRecord.data().asByteArray(), + StandardCharsets.UTF_8); + + if (fetchedData.equals(expectedData)) { + found = true; + break; + } + } + + assertThat(found).as("Record %s was not found in fetched records", expectedData).isTrue(); + } + } + + /** + * Tests that records are processed in the correct order even when there are concurrent events. 
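+ *
+ * <p>Submissions race, but the per-shard processing lock serializes them, so the test asserts only
+ * set-equality across events plus relative order within each event. A sketch of the submission step (using
+ * the manually triggered {@code testExecutor} from the base class):
+ *
+ * <pre>{@code
+ * for (SubscribeToShardEvent event : events) {
+ *     futures.add(CompletableFuture.runAsync(
+ *         () -> testSubscription.processSubscriptionEvent(event), testExecutor));
+ * }
+ * testExecutor.triggerAll(); // run every queued task deterministically
+ * }</pre>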
+ */ + @Test + @Timeout(value = 30) + public void testRecordOrderingWithConcurrentEvents() throws Exception { + // Create a blocking queue to store processed records + BlockingQueue processedRecords = new LinkedBlockingQueue<>(); + + // Create a custom TestableSubscription that captures processed records + TestableSubscription testSubscription = createTestableSubscription( + SHARD_ID, + StartingPosition.fromStart(), + processedRecords); + + // Create test events with records + int numEvents = 10; + int recordsPerEvent = 5; + List events = new ArrayList<>(); + + for (int i = 0; i < numEvents; i++) { + List records = new ArrayList<>(); + for (int j = 0; j < recordsPerEvent; j++) { + int recordNum = i * recordsPerEvent + j; + records.add(FanOutKinesisTestUtils.createTestRecord("record-" + recordNum)); + } + + String sequenceNumber = "sequence-" + i; + events.add(FanOutKinesisTestUtils.createTestEvent(sequenceNumber, records)); + } + + // Process events concurrently + List> futures = new ArrayList<>(); + for (SubscribeToShardEvent event : events) { + CompletableFuture future = CompletableFuture.runAsync(() -> { + testSubscription.processSubscriptionEvent(event); + }, testExecutor); + futures.add(future); + } + + // Trigger all tasks in the executor + testExecutor.triggerAll(); + + // Wait for all events to be processed + CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).get(); + + // Verify that all records were processed + List allProcessedRecords = new ArrayList<>(); + processedRecords.drainTo(allProcessedRecords); + + assertThat(allProcessedRecords).hasSize(numEvents * recordsPerEvent); + + // Verify that all records were processed + List processedDataStrings = allProcessedRecords.stream() + .map(r -> new String(r.data().asByteArray(), StandardCharsets.UTF_8)) + .collect(Collectors.toList()); + + // Create a list of all expected record data strings + List expectedDataStrings = new ArrayList<>(); + for (int i = 0; i < numEvents; i++) { + for (int j = 0; j < recordsPerEvent; j++) { + expectedDataStrings.add("record-" + (i * recordsPerEvent + j)); + } + } + + // Verify that all expected records are present in the processed records + // We can't guarantee the exact order due to concurrency, but we can verify all records are there + assertThat(processedDataStrings).containsExactlyInAnyOrderElementsOf(expectedDataStrings); + + // Verify that records from the same event are processed in order + // We do this by checking if there are any records from the same event that are out of order + boolean recordsInOrder = true; + for (int i = 0; i < numEvents; i++) { + List eventRecordIndices = new ArrayList<>(); + for (int j = 0; j < recordsPerEvent; j++) { + String recordData = "record-" + (i * recordsPerEvent + j); + int index = processedDataStrings.indexOf(recordData); + eventRecordIndices.add(index); + } + + // Check if the indices are in ascending order + for (int j = 1; j < eventRecordIndices.size(); j++) { + if (eventRecordIndices.get(j) < eventRecordIndices.get(j - 1)) { + recordsInOrder = false; + break; + } + } + } + + // We expect records from the same event to be in order + assertThat(recordsInOrder).as("Records from the same event should be processed in order").isTrue(); + } +} diff --git a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardRestartTest.java 
b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardRestartTest.java new file mode 100644 index 00000000..2cd7cb7b --- /dev/null +++ b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardRestartTest.java @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.connector.kinesis.source.reader.fanout; + +import org.apache.flink.configuration.Configuration; +import org.apache.flink.connector.base.source.reader.splitreader.SplitsAddition; +import org.apache.flink.connector.kinesis.source.config.KinesisSourceConfigOptions; +import org.apache.flink.connector.kinesis.source.metrics.KinesisShardMetrics; +import org.apache.flink.connector.kinesis.source.proxy.AsyncStreamProxy; +import org.apache.flink.connector.kinesis.source.split.KinesisShardSplit; +import org.apache.flink.connector.kinesis.source.split.StartingPosition; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.mockito.ArgumentCaptor; +import org.mockito.Mockito; + +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeoutException; + +import static org.apache.flink.connector.kinesis.source.util.TestUtil.CONSUMER_ARN; +import static org.apache.flink.connector.kinesis.source.util.TestUtil.SHARD_ID; +import static org.apache.flink.connector.kinesis.source.util.TestUtil.STREAM_ARN; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * Tests for the restart behavior in {@link FanOutKinesisShardSubscription} + * and {@link FanOutKinesisShardSplitReader}. + */ +public class FanOutKinesisShardRestartTest extends FanOutKinesisShardTestBase { + + /** + * Tests that when a restart happens, the correct starting position is used to reactivate the subscription. 
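+ *
+ * <p>The restart is simulated by building a second reader over a split whose position has been advanced with
+ * {@code StartingPosition.continueFromSequenceNumber(...)}; a sketch of that step as used below:
+ *
+ * <pre>{@code
+ * StartingPosition updated =
+ *     StartingPosition.continueFromSequenceNumber("sequence-after-processing");
+ * KinesisShardSplit updatedSplit =
+ *     FanOutKinesisTestUtils.createTestSplit(STREAM_ARN, SHARD_ID, updated);
+ * restartedReader.handleSplitsChanges(new SplitsAddition<>(Collections.singletonList(updatedSplit)));
+ * }</pre>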
+ */ + @Test + @Timeout(value = 30) + public void testRestartUsesCorrectStartingPosition() throws Exception { + // Create a custom AsyncStreamProxy that will capture the starting position + AsyncStreamProxy customProxy = Mockito.mock(AsyncStreamProxy.class); + ArgumentCaptor startingPositionCaptor = ArgumentCaptor.forClass(StartingPosition.class); + + when(customProxy.subscribeToShard( + any(String.class), + any(String.class), + startingPositionCaptor.capture(), + any())) + .thenReturn(CompletableFuture.completedFuture(null)); + + // Create a metrics map for the shard + Map metricsMap = new HashMap<>(); + KinesisShardSplit split = FanOutKinesisTestUtils.createTestSplit( + STREAM_ARN, + SHARD_ID, + StartingPosition.fromStart()); + metricsMap.put(SHARD_ID, new KinesisShardMetrics(split, mockMetricGroup)); + + // Create a reader + // Create a Configuration object and set the timeout + Configuration configuration = new Configuration(); + configuration.set(KinesisSourceConfigOptions.EFO_CONSUMER_SUBSCRIPTION_TIMEOUT, TEST_SUBSCRIPTION_TIMEOUT); + + FanOutKinesisShardSplitReader reader = new FanOutKinesisShardSplitReader( + customProxy, + CONSUMER_ARN, + metricsMap, + configuration, + createTestSubscriptionFactory(), + testExecutor); + + // Add a split to the reader + reader.handleSplitsChanges(new SplitsAddition<>(Collections.singletonList(split))); + + // Verify that the subscription was activated with the initial starting position + verify(customProxy, times(1)).subscribeToShard( + eq(CONSUMER_ARN), + eq(SHARD_ID), + any(StartingPosition.class), + any()); + + assertThat(startingPositionCaptor.getValue()).isEqualTo(StartingPosition.fromStart()); + + // Create a new split with the updated starting position + String continuationSequenceNumber = "sequence-after-processing"; + StartingPosition updatedPosition = StartingPosition.continueFromSequenceNumber(continuationSequenceNumber); + KinesisShardSplit updatedSplit = FanOutKinesisTestUtils.createTestSplit( + STREAM_ARN, + SHARD_ID, + updatedPosition); + + // Simulate a restart by creating a new reader with the updated split + // Create a Configuration object and set the timeout + Configuration restartConfiguration = new Configuration(); + restartConfiguration.set(KinesisSourceConfigOptions.EFO_CONSUMER_SUBSCRIPTION_TIMEOUT, TEST_SUBSCRIPTION_TIMEOUT); + + FanOutKinesisShardSplitReader restartedReader = new FanOutKinesisShardSplitReader( + customProxy, + CONSUMER_ARN, + metricsMap, + restartConfiguration, + createTestSubscriptionFactory(), + testExecutor); + + // Add the updated split to the restarted reader + restartedReader.handleSplitsChanges(new SplitsAddition<>(Collections.singletonList(updatedSplit))); + + // Verify that the subscription was reactivated with the updated starting position + verify(customProxy, times(2)).subscribeToShard( + eq(CONSUMER_ARN), + eq(SHARD_ID), + any(StartingPosition.class), + any()); + + // Get the second captured value (from the restart) + StartingPosition capturedPosition = startingPositionCaptor.getAllValues().get(1); + + // Verify it matches our expected updated position + assertThat(capturedPosition.getShardIteratorType()).isEqualTo(updatedPosition.getShardIteratorType()); + assertThat(capturedPosition.getStartingMarker()).isEqualTo(updatedPosition.getStartingMarker()); + } + + /** + * Tests that when exceptions are thrown, the job is restarted. 
+ */ + @Test + @Timeout(value = 30) + public void testExceptionsProperlyHandled() throws Exception { + // Create a metrics map for the shard + Map metricsMap = new HashMap<>(); + KinesisShardSplit split = FanOutKinesisTestUtils.createTestSplit( + STREAM_ARN, + SHARD_ID, + StartingPosition.fromStart()); + metricsMap.put(SHARD_ID, new KinesisShardMetrics(split, mockMetricGroup)); + + // Test with different types of exceptions + testExceptionHandling(software.amazon.awssdk.services.kinesis.model.ResourceNotFoundException.builder().message("Resource not found").build(), true); + testExceptionHandling(new IOException("IO exception"), true); + testExceptionHandling(new TimeoutException("Timeout"), true); + testExceptionHandling(new RuntimeException("Runtime exception"), false); + } + + /** + * Helper method to test exception handling. + */ + private void testExceptionHandling(Exception exception, boolean isRecoverable) throws Exception { + // Create a metrics map for the shard + Map metricsMap = new HashMap<>(); + KinesisShardSplit split = FanOutKinesisTestUtils.createTestSplit( + STREAM_ARN, + SHARD_ID, + StartingPosition.fromStart()); + metricsMap.put(SHARD_ID, new KinesisShardMetrics(split, mockMetricGroup)); + + // Create a mock AsyncStreamProxy that throws the specified exception + AsyncStreamProxy exceptionProxy = Mockito.mock(AsyncStreamProxy.class); + CompletableFuture failedFuture = new CompletableFuture<>(); + failedFuture.completeExceptionally(exception); + when(exceptionProxy.subscribeToShard(any(), any(), any(), any())) + .thenReturn(failedFuture); + + // Create a reader with the exception-throwing proxy + // Create a Configuration object and set the timeout + Configuration configuration = new Configuration(); + configuration.set(KinesisSourceConfigOptions.EFO_CONSUMER_SUBSCRIPTION_TIMEOUT, TEST_SUBSCRIPTION_TIMEOUT); + + FanOutKinesisShardSplitReader reader = new FanOutKinesisShardSplitReader( + exceptionProxy, + CONSUMER_ARN, + metricsMap, + configuration, + createTestSubscriptionFactory(), + testExecutor); + + // Add a split to the reader + reader.handleSplitsChanges(new SplitsAddition<>(Collections.singletonList(split))); + + // If the exception is recoverable, the reader should try to reactivate the subscription + // If not, it should propagate the exception + if (isRecoverable) { + // Verify that the subscription was activated + verify(exceptionProxy, times(1)).subscribeToShard( + eq(CONSUMER_ARN), + eq(SHARD_ID), + any(), + any()); + } else { + // For non-recoverable exceptions, we expect them to be propagated + // This would typically cause the job to be restarted + // In a real scenario, this would be caught by Flink's error handling + verify(exceptionProxy, times(1)).subscribeToShard( + eq(CONSUMER_ARN), + eq(SHARD_ID), + any(), + any()); + } + } +} diff --git a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSplitReaderTest.java b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSplitReaderTest.java index fbaaf696..8a00d192 100644 --- a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSplitReaderTest.java +++ b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSplitReaderTest.java @@ -49,7 +49,7 @@ 
import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively; /** Test for {@link FanOutKinesisShardSplitReader}. */ -public class FanOutKinesisShardSplitReaderTest { +public class FanOutKinesisShardSplitReaderTest extends FanOutKinesisShardTestBase { private static final String TEST_SHARD_ID = TestUtil.generateShardId(1); FanOutKinesisShardSplitReader splitReader; @@ -82,7 +82,9 @@ public void testNoAssignedSplitsHandledGracefully() throws Exception { testAsyncStreamProxy, CONSUMER_ARN, shardMetricGroupMap, - newConfigurationForTest()); + newConfigurationForTest(), + createTestSubscriptionFactory(), + testExecutor); RecordsWithSplitIds retrievedRecords = splitReader.fetch(); assertThat(retrievedRecords.nextRecordFromSplit()).isNull(); @@ -99,7 +101,9 @@ public void testAssignedSplitHasNoRecordsHandledGracefully() throws Exception { testAsyncStreamProxy, CONSUMER_ARN, shardMetricGroupMap, - newConfigurationForTest()); + newConfigurationForTest(), + createTestSubscriptionFactory(), + testExecutor); splitReader.handleSplitsChanges( new SplitsAddition<>(Collections.singletonList(getTestSplit(TEST_SHARD_ID)))); @@ -122,7 +126,9 @@ public void testSplitWithExpiredShardHandledAsCompleted() throws Exception { testAsyncStreamProxy, CONSUMER_ARN, shardMetricGroupMap, - newConfigurationForTest()); + newConfigurationForTest(), + createTestSubscriptionFactory(), + testExecutor); splitReader.handleSplitsChanges( new SplitsAddition<>(Collections.singletonList(getTestSplit(TEST_SHARD_ID)))); @@ -143,7 +149,9 @@ public void testWakeUpIsNoOp() { testAsyncStreamProxy, CONSUMER_ARN, shardMetricGroupMap, - newConfigurationForTest()); + newConfigurationForTest(), + createTestSubscriptionFactory(), + testExecutor); // When wakeup is called // Then no exception is thrown and no-op @@ -160,7 +168,9 @@ public void testCloseClosesStreamProxy() throws Exception { trackCloseStreamProxy, CONSUMER_ARN, shardMetricGroupMap, - newConfigurationForTest()); + newConfigurationForTest(), + createTestSubscriptionFactory(), + testExecutor); // When split reader is not closed // Then stream proxy is still open diff --git a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSplitReaderThreadPoolTest.java b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSplitReaderThreadPoolTest.java new file mode 100644 index 00000000..d6ee85fb --- /dev/null +++ b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSplitReaderThreadPoolTest.java @@ -0,0 +1,406 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.connector.kinesis.source.reader.fanout; + +import org.apache.flink.configuration.Configuration; +import org.apache.flink.connector.base.source.reader.splitreader.SplitsAddition; +import org.apache.flink.connector.kinesis.source.config.KinesisSourceConfigOptions; +import org.apache.flink.connector.kinesis.source.metrics.KinesisShardMetrics; +import org.apache.flink.connector.kinesis.source.proxy.AsyncStreamProxy; +import org.apache.flink.connector.kinesis.source.split.KinesisShardSplit; +import org.apache.flink.connector.kinesis.source.split.StartingPosition; +import org.apache.flink.connector.kinesis.source.util.TestUtil; +import org.apache.flink.core.testutils.ManuallyTriggeredScheduledExecutorService; +import org.apache.flink.metrics.MetricGroup; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.mockito.Mockito; +import software.amazon.awssdk.services.kinesis.model.SubscribeToShardEvent; + +import java.lang.reflect.Field; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.apache.flink.connector.kinesis.source.util.TestUtil.CONSUMER_ARN; +import static org.apache.flink.connector.kinesis.source.util.TestUtil.SHARD_ID; +import static org.apache.flink.connector.kinesis.source.util.TestUtil.STREAM_ARN; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +/** + * Tests for the thread pool behavior in {@link FanOutKinesisShardSplitReader}. + */ +public class FanOutKinesisShardSplitReaderThreadPoolTest { + private static final Duration TEST_SUBSCRIPTION_TIMEOUT = Duration.ofMillis(1000); + private static final int NUM_SHARDS = 10; + private static final int EVENTS_PER_SHARD = 5; + + private AsyncStreamProxy mockAsyncStreamProxy; + private FanOutKinesisShardSplitReader splitReader; + + @BeforeEach + public void setUp() { + mockAsyncStreamProxy = Mockito.mock(AsyncStreamProxy.class); + when(mockAsyncStreamProxy.subscribeToShard(any(), any(), any(), any())) + .thenReturn(CompletableFuture.completedFuture(null)); + } + + /** + * Tests that the thread pool correctly processes events from multiple shards. 
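+     *
+     * <p>The test drives a {@code ManuallyTriggeredScheduledExecutorService}, whose submitted tasks
+     * queue up until explicitly triggered. A minimal sketch of that pattern (the counter is
+     * illustrative, not part of the connector):
+     * <pre>{@code
+     * ManuallyTriggeredScheduledExecutorService exec = new ManuallyTriggeredScheduledExecutorService();
+     * AtomicInteger counter = new AtomicInteger();
+     * exec.execute(counter::incrementAndGet); // queued, not yet run
+     * exec.triggerAll();                      // runs all queued tasks on the calling thread
+     * // counter.get() == 1
+     * }</pre>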
+ */ + @Test + @Timeout(value = 30) + public void testThreadPoolProcessesMultipleShards() throws Exception { + // Use a counter to track events processed + AtomicInteger processedEvents = new AtomicInteger(0); + int expectedEvents = NUM_SHARDS * EVENTS_PER_SHARD; + + // Create a manually triggered executor service + ManuallyTriggeredScheduledExecutorService testExecutor = new ManuallyTriggeredScheduledExecutorService(); + + // Create a map to store our test subscriptions + java.util.Map testSubscriptions = new java.util.HashMap<>(); + + // Create a custom subscription factory that creates test subscriptions + FanOutKinesisShardSplitReader.SubscriptionFactory customFactory = + (proxy, consumerArn, shardId, startingPosition, timeout, executor) -> { + TestSubscription subscription = new TestSubscription( + proxy, consumerArn, shardId, startingPosition, timeout, executor, + processedEvents, expectedEvents); + testSubscriptions.put(shardId, subscription); + return subscription; + }; + + // Create a metrics map for each shard + java.util.Map metricsMap = new java.util.HashMap<>(); + for (int i = 0; i < NUM_SHARDS; i++) { + String shardId = SHARD_ID + "-" + i; + KinesisShardSplit split = new KinesisShardSplit( + STREAM_ARN, + shardId, + StartingPosition.fromStart(), + Collections.emptySet(), + TestUtil.STARTING_HASH_KEY_TEST_VALUE, + TestUtil.ENDING_HASH_KEY_TEST_VALUE); + MetricGroup metricGroup = mock(MetricGroup.class); + when(metricGroup.addGroup(any(String.class))).thenReturn(metricGroup); + when(metricGroup.addGroup(any(String.class), any(String.class))).thenReturn(metricGroup); + metricsMap.put(shardId, new KinesisShardMetrics(split, metricGroup)); + } + + // Create a split reader with the custom factory and test executor + Configuration configuration = new Configuration(); + configuration.set(KinesisSourceConfigOptions.EFO_CONSUMER_SUBSCRIPTION_TIMEOUT, TEST_SUBSCRIPTION_TIMEOUT); + + splitReader = new FanOutKinesisShardSplitReader( + mockAsyncStreamProxy, + CONSUMER_ARN, + metricsMap, + configuration, + customFactory, + testExecutor); + + // Add multiple splits to the reader + List splits = new ArrayList<>(); + for (int i = 0; i < NUM_SHARDS; i++) { + String shardId = SHARD_ID + "-" + i; + KinesisShardSplit split = new KinesisShardSplit( + STREAM_ARN, + shardId, + StartingPosition.fromStart(), + Collections.emptySet(), + TestUtil.STARTING_HASH_KEY_TEST_VALUE, + TestUtil.ENDING_HASH_KEY_TEST_VALUE); + splits.add(split); + } + splitReader.handleSplitsChanges(new SplitsAddition<>(splits)); + + // Trigger all tasks in the executor to process subscription activations + testExecutor.triggerAll(); + + // Process all events for all shards by directly calling nextEvent() on each subscription + for (int i = 0; i < EVENTS_PER_SHARD; i++) { + for (String shardId : testSubscriptions.keySet()) { + TestSubscription subscription = testSubscriptions.get(shardId); + // Force the subscription to process an event + SubscribeToShardEvent event = subscription.nextEvent(); + // Trigger all tasks in the executor after each event + testExecutor.triggerAll(); + } + } + + // Verify that all events were processed + assertThat(processedEvents.get()).as("All events should be processed").isEqualTo(expectedEvents); + } + + /** + * Tests that the thread pool has natural flow control that prevents queue overflow. 
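+     *
+     * <p>A sketch of the invariant being exercised, assuming the per-shard queue capacity of 2
+     * (the wiring below is illustrative only):
+     * <pre>{@code
+     * BlockingQueue<SubscribeToShardEvent> shardQueue = new ArrayBlockingQueue<>(2);
+     * shardQueue.put(event1);
+     * shardQueue.put(event2);
+     * // a third put() now blocks its processing task, so each shard contributes at most
+     * // two buffered events plus one blocked task to the executor's workload
+     * }</pre>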
+ */ + @Test + @Timeout(value = 30) + public void testThreadPoolFlowControl() throws Exception { + // Create a counter to track the maximum number of queued tasks + AtomicInteger maxQueuedTasks = new AtomicInteger(0); + AtomicInteger currentQueuedTasks = new AtomicInteger(0); + + // Create a custom AsyncStreamProxy that will delay subscription events + AsyncStreamProxy customProxy = Mockito.mock(AsyncStreamProxy.class); + when(customProxy.subscribeToShard(any(), any(), any(), any())) + .thenReturn(CompletableFuture.completedFuture(null)); + + // Create a metrics map for each shard + java.util.Map metricsMap = new java.util.HashMap<>(); + for (int i = 0; i < NUM_SHARDS; i++) { + String shardId = SHARD_ID + "-" + i; + KinesisShardSplit split = new KinesisShardSplit( + STREAM_ARN, + shardId, + StartingPosition.fromStart(), + Collections.emptySet(), + TestUtil.STARTING_HASH_KEY_TEST_VALUE, + TestUtil.ENDING_HASH_KEY_TEST_VALUE); + MetricGroup metricGroup = mock(MetricGroup.class); + when(metricGroup.addGroup(any(String.class))).thenReturn(metricGroup); + when(metricGroup.addGroup(any(String.class), any(String.class))).thenReturn(metricGroup); + metricsMap.put(shardId, new KinesisShardMetrics(split, metricGroup)); + } + + // Create a split reader + // Create a Configuration object and set the timeout + Configuration configuration = new Configuration(); + configuration.set(KinesisSourceConfigOptions.EFO_CONSUMER_SUBSCRIPTION_TIMEOUT, TEST_SUBSCRIPTION_TIMEOUT); + + splitReader = new FanOutKinesisShardSplitReader( + customProxy, + CONSUMER_ARN, + metricsMap, + configuration); + + // Get access to the executor service + ExecutorService executor = getExecutorService(splitReader); + assertThat(executor).isInstanceOf(ThreadPoolExecutor.class); + ThreadPoolExecutor threadPoolExecutor = (ThreadPoolExecutor) executor; + + // Monitor the queue size + Thread monitorThread = new Thread(() -> { + try { + while (!Thread.currentThread().isInterrupted()) { + int queueSize = threadPoolExecutor.getQueue().size(); + currentQueuedTasks.set(queueSize); + maxQueuedTasks.updateAndGet(current -> Math.max(current, queueSize)); + Thread.sleep(10); + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + }); + monitorThread.start(); + + // Create a custom subscription factory that adds artificial delay + FanOutKinesisShardSplitReader.SubscriptionFactory customFactory = + (proxy, consumerArn, shardId, startingPosition, timeout, executorService) -> { + return new FanOutKinesisShardSubscription( + proxy, consumerArn, shardId, startingPosition, timeout, executorService) { + @Override + public void processSubscriptionEvent(SubscribeToShardEvent event) { + try { + // Add artificial delay to simulate processing time + Thread.sleep(50); + super.processSubscriptionEvent(event); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } + + @Override + public SubscribeToShardEvent nextEvent() { + // Create a test event + return createTestEvent("sequence-" + shardId); + } + }; + }; + + // Set the custom factory using reflection + setSubscriptionFactory(splitReader, customFactory); + + // Add multiple splits to the reader + List splits = new ArrayList<>(); + for (int i = 0; i < NUM_SHARDS; i++) { + String shardId = SHARD_ID + "-" + i; + KinesisShardSplit split = new KinesisShardSplit( + STREAM_ARN, + shardId, + StartingPosition.fromStart(), + Collections.emptySet(), + TestUtil.STARTING_HASH_KEY_TEST_VALUE, + TestUtil.ENDING_HASH_KEY_TEST_VALUE); + splits.add(split); + } + 
splitReader.handleSplitsChanges(new SplitsAddition<>(splits)); + + // Fetch records multiple times to trigger event processing + for (int i = 0; i < EVENTS_PER_SHARD * 2; i++) { + for (int j = 0; j < NUM_SHARDS; j++) { + splitReader.fetch(); + } + } + + // Wait for some time to allow tasks to be queued and processed + Thread.sleep(1000); + + // Stop the monitor thread + monitorThread.interrupt(); + monitorThread.join(1000); + + // Verify that the maximum queue size is bounded + // The theoretical maximum is 2 * NUM_SHARDS (each subscription has a queue of 2) + assertThat(maxQueuedTasks.get()).isLessThanOrEqualTo(2 * NUM_SHARDS); + } + + /** + * Tests that the thread pool is properly shut down when the split reader is closed. + */ + @Test + @Timeout(value = 30) + public void testThreadPoolShutdown() throws Exception { + // Create a metrics map for the test + java.util.Map metricsMap = new java.util.HashMap<>(); + KinesisShardSplit split = new KinesisShardSplit( + STREAM_ARN, + SHARD_ID, + StartingPosition.fromStart(), + Collections.emptySet(), + TestUtil.STARTING_HASH_KEY_TEST_VALUE, + TestUtil.ENDING_HASH_KEY_TEST_VALUE); + MetricGroup metricGroup = mock(MetricGroup.class); + when(metricGroup.addGroup(any(String.class))).thenReturn(metricGroup); + when(metricGroup.addGroup(any(String.class), any(String.class))).thenReturn(metricGroup); + metricsMap.put(SHARD_ID, new KinesisShardMetrics(split, metricGroup)); + + // Create a split reader + // Create a Configuration object and set the timeout + Configuration configuration = new Configuration(); + configuration.set(KinesisSourceConfigOptions.EFO_CONSUMER_SUBSCRIPTION_TIMEOUT, TEST_SUBSCRIPTION_TIMEOUT); + + splitReader = new FanOutKinesisShardSplitReader( + mockAsyncStreamProxy, + CONSUMER_ARN, + metricsMap, + configuration); + + // Get access to the executor service + ExecutorService executor = getExecutorService(splitReader); + assertThat(executor).isNotNull(); + + // Close the split reader + splitReader.close(); + + // Verify that the executor service is shut down + assertThat(executor.isShutdown()).isTrue(); + } + + /** + * Creates a test SubscribeToShardEvent with the given continuation sequence number. + */ + private SubscribeToShardEvent createTestEvent(String continuationSequenceNumber) { + return SubscribeToShardEvent.builder() + .continuationSequenceNumber(continuationSequenceNumber) + .millisBehindLatest(0L) + .records(new ArrayList<>()) + .build(); + } + + /** + * Gets the executor service from the split reader using reflection. + */ + private ExecutorService getExecutorService(FanOutKinesisShardSplitReader splitReader) throws Exception { + Field field = FanOutKinesisShardSplitReader.class.getDeclaredField("sharedShardSubscriptionExecutor"); + field.setAccessible(true); + return (ExecutorService) field.get(splitReader); + } + + /** + * Sets the subscription factory in the split reader using reflection. + */ + private void setSubscriptionFactory( + FanOutKinesisShardSplitReader splitReader, + FanOutKinesisShardSplitReader.SubscriptionFactory factory) throws Exception { + Field field = FanOutKinesisShardSplitReader.class.getDeclaredField("subscriptionFactory"); + field.setAccessible(true); + field.set(splitReader, factory); + } + + /** + * A test subscription that ensures we process exactly EVENTS_PER_SHARD events per shard. 
+ */ + private static class TestSubscription extends FanOutKinesisShardSubscription { + private final AtomicInteger eventsProcessed = new AtomicInteger(0); + private final AtomicInteger globalCounter; + private final int expectedTotal; + private final String shardId; + + public TestSubscription( + AsyncStreamProxy proxy, + String consumerArn, + String shardId, + StartingPosition startingPosition, + Duration timeout, + ExecutorService executor, + AtomicInteger globalCounter, + int expectedTotal) { + super(proxy, consumerArn, shardId, startingPosition, timeout, executor); + this.shardId = shardId; + this.globalCounter = globalCounter; + this.expectedTotal = expectedTotal; + } + + @Override + public SubscribeToShardEvent nextEvent() { + int current = eventsProcessed.get(); + + // Only return events up to EVENTS_PER_SHARD + if (current < EVENTS_PER_SHARD) { + // Create a test event + SubscribeToShardEvent event = SubscribeToShardEvent.builder() + .continuationSequenceNumber("sequence-" + shardId + "-" + current) + .millisBehindLatest(0L) + .records(new ArrayList<>()) + .build(); + + // Increment the counters + eventsProcessed.incrementAndGet(); + int globalCount = globalCounter.incrementAndGet(); + return event; + } + + // Return null when we've processed all events for this shard + return null; + } + } +} diff --git a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardStartingPositionTest.java b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardStartingPositionTest.java new file mode 100644 index 00000000..a633d22a --- /dev/null +++ b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardStartingPositionTest.java @@ -0,0 +1,210 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.connector.kinesis.source.reader.fanout; + +import org.apache.flink.connector.kinesis.source.split.StartingPosition; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import software.amazon.awssdk.services.kinesis.model.Record; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingQueue; + +import static org.apache.flink.connector.kinesis.source.util.TestUtil.SHARD_ID; +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for the starting position behavior in {@link FanOutKinesisShardSubscription}. 
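+ *
+ * <p>The ordering contract these tests pin down, in outline (a sketch, not the exact code):
+ * <pre>{@code
+ * queue.put(event); // may block, fail, or be interrupted
+ * // only reached after a successful put, so a failed put never advances the restart position:
+ * startingPosition = StartingPosition.continueFromSequenceNumber(event.continuationSequenceNumber());
+ * }</pre>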
+ */ +public class FanOutKinesisShardStartingPositionTest extends FanOutKinesisShardTestBase { + + /** + * Tests that the starting position is correctly recorded after queue.put for a single shard. + */ + @Test + @Timeout(value = 30) + public void testStartingPositionRecordedAfterQueuePutSingleShard() throws Exception { + // Create a blocking queue to store processed records + BlockingQueue processedRecords = new LinkedBlockingQueue<>(); + + // Create a custom TestableSubscription that captures processed records + TestableSubscription testSubscription = createTestableSubscription( + SHARD_ID, + StartingPosition.fromStart(), + processedRecords); + + // Create a test event with records + String continuationSequenceNumber = "sequence-1"; + List records = new ArrayList<>(); + records.add(FanOutKinesisTestUtils.createTestRecord("record-1")); + records.add(FanOutKinesisTestUtils.createTestRecord("record-2")); + + // Process the event + testSubscription.processSubscriptionEvent( + FanOutKinesisTestUtils.createTestEvent(continuationSequenceNumber, records)); + + // Verify that all records were processed + List allProcessedRecords = new ArrayList<>(); + processedRecords.drainTo(allProcessedRecords); + assertThat(allProcessedRecords).hasSize(2); + + // Verify that the starting position was updated correctly + assertThat(testSubscription.getStartingPosition().getStartingMarker()) + .isEqualTo(continuationSequenceNumber); + } + + /** + * Tests that the starting position is correctly recorded after queue.put for multiple shards. + */ + @Test + @Timeout(value = 30) + public void testStartingPositionRecordedAfterQueuePutMultipleShards() throws Exception { + // Create blocking queues to store processed records for each shard + BlockingQueue processedRecordsShard1 = new LinkedBlockingQueue<>(); + BlockingQueue processedRecordsShard2 = new LinkedBlockingQueue<>(); + + // Create custom TestableSubscriptions for each shard + TestableSubscription subscription1 = createTestableSubscription( + SHARD_ID_1, + StartingPosition.fromStart(), + processedRecordsShard1); + + TestableSubscription subscription2 = createTestableSubscription( + SHARD_ID_2, + StartingPosition.fromStart(), + processedRecordsShard2); + + // Create test events with records for each shard + String continuationSequenceNumber1 = "sequence-shard1"; + String continuationSequenceNumber2 = "sequence-shard2"; + + List recordsShard1 = new ArrayList<>(); + recordsShard1.add(FanOutKinesisTestUtils.createTestRecord("shard1-record-1")); + recordsShard1.add(FanOutKinesisTestUtils.createTestRecord("shard1-record-2")); + + List recordsShard2 = new ArrayList<>(); + recordsShard2.add(FanOutKinesisTestUtils.createTestRecord("shard2-record-1")); + recordsShard2.add(FanOutKinesisTestUtils.createTestRecord("shard2-record-2")); + + // Process the events + subscription1.processSubscriptionEvent( + FanOutKinesisTestUtils.createTestEvent(continuationSequenceNumber1, recordsShard1)); + subscription2.processSubscriptionEvent( + FanOutKinesisTestUtils.createTestEvent(continuationSequenceNumber2, recordsShard2)); + + // Verify that all records were processed for shard 1 + List allProcessedRecordsShard1 = new ArrayList<>(); + processedRecordsShard1.drainTo(allProcessedRecordsShard1); + assertThat(allProcessedRecordsShard1).hasSize(2); + + // Verify that all records were processed for shard 2 + List allProcessedRecordsShard2 = new ArrayList<>(); + processedRecordsShard2.drainTo(allProcessedRecordsShard2); + assertThat(allProcessedRecordsShard2).hasSize(2); + + // Verify that 
the starting positions were updated correctly + assertThat(subscription1.getStartingPosition().getStartingMarker()) + .isEqualTo(continuationSequenceNumber1); + assertThat(subscription2.getStartingPosition().getStartingMarker()) + .isEqualTo(continuationSequenceNumber2); + } + + /** + * Tests that the starting position is not recorded when queue.put fails for a single shard. + */ + @Test + @Timeout(value = 30) + public void testStartingPositionNotRecordedWhenQueuePutFailsSingleShard() throws Exception { + // Create a custom TestableSubscription with a failing queue + TestableSubscription testSubscription = createTestableSubscription( + SHARD_ID, + StartingPosition.fromStart(), + null); // Null queue will cause queue.put to be skipped + + // Set the flag to not update starting position + testSubscription.setShouldUpdateStartingPosition(false); + + // Create a test event with records + String continuationSequenceNumber = "sequence-1"; + List records = new ArrayList<>(); + records.add(FanOutKinesisTestUtils.createTestRecord("record-1")); + records.add(FanOutKinesisTestUtils.createTestRecord("record-2")); + + // Store the original starting position + StartingPosition originalPosition = testSubscription.getStartingPosition(); + + // Process the event + testSubscription.processSubscriptionEvent( + FanOutKinesisTestUtils.createTestEvent(continuationSequenceNumber, records)); + + // Verify that the starting position was not updated + assertThat(testSubscription.getStartingPosition()).isEqualTo(originalPosition); + } + + /** + * Tests that the starting position is not recorded when queue.put fails for multiple shards. + */ + @Test + @Timeout(value = 30) + public void testStartingPositionNotRecordedWhenQueuePutFailsMultipleShards() throws Exception { + // Create custom TestableSubscriptions with failing queues + TestableSubscription subscription1 = createTestableSubscription( + SHARD_ID_1, + StartingPosition.fromStart(), + null); // Null queue will cause queue.put to be skipped + + TestableSubscription subscription2 = createTestableSubscription( + SHARD_ID_2, + StartingPosition.fromStart(), + null); // Null queue will cause queue.put to be skipped + + // Set the flags to not update starting positions + subscription1.setShouldUpdateStartingPosition(false); + subscription2.setShouldUpdateStartingPosition(false); + + // Create test events with records for each shard + String continuationSequenceNumber1 = "sequence-shard1"; + String continuationSequenceNumber2 = "sequence-shard2"; + + List recordsShard1 = new ArrayList<>(); + recordsShard1.add(FanOutKinesisTestUtils.createTestRecord("shard1-record-1")); + recordsShard1.add(FanOutKinesisTestUtils.createTestRecord("shard1-record-2")); + + List recordsShard2 = new ArrayList<>(); + recordsShard2.add(FanOutKinesisTestUtils.createTestRecord("shard2-record-1")); + recordsShard2.add(FanOutKinesisTestUtils.createTestRecord("shard2-record-2")); + + // Store the original starting positions + StartingPosition originalPosition1 = subscription1.getStartingPosition(); + StartingPosition originalPosition2 = subscription2.getStartingPosition(); + + // Process the events + subscription1.processSubscriptionEvent( + FanOutKinesisTestUtils.createTestEvent(continuationSequenceNumber1, recordsShard1)); + subscription2.processSubscriptionEvent( + FanOutKinesisTestUtils.createTestEvent(continuationSequenceNumber2, recordsShard2)); + + // Verify that the starting positions were not updated + assertThat(subscription1.getStartingPosition()).isEqualTo(originalPosition1); + 
assertThat(subscription2.getStartingPosition()).isEqualTo(originalPosition2); + } +} diff --git a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSubscriptionThreadSafetyTest.java b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSubscriptionThreadSafetyTest.java new file mode 100644 index 00000000..1e2a32f6 --- /dev/null +++ b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSubscriptionThreadSafetyTest.java @@ -0,0 +1,414 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.connector.kinesis.source.reader.fanout; + +import org.apache.flink.connector.kinesis.source.exception.KinesisStreamsSourceException; +import org.apache.flink.connector.kinesis.source.proxy.AsyncStreamProxy; +import org.apache.flink.connector.kinesis.source.split.StartingPosition; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.mockito.Mockito; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; +import software.amazon.awssdk.services.kinesis.model.SubscribeToShardEvent; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.apache.flink.connector.kinesis.source.util.TestUtil.CONSUMER_ARN; +import static org.apache.flink.connector.kinesis.source.util.TestUtil.SHARD_ID; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * Tests for thread safety in {@link FanOutKinesisShardSubscription}. 
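+ *
+ * <p>The hand-off these tests revolve around, in outline: the Netty event loop thread must never
+ * block, so the subscriber only enqueues the blocking work (a sketch, not the exact code):
+ * <pre>{@code
+ * public void onNext(SubscribeToShardEvent event) {
+ *     subscriptionEventProcessingExecutor.execute(() -> processSubscriptionEvent(event));
+ * }
+ * }</pre>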
+ */ +public class FanOutKinesisShardSubscriptionThreadSafetyTest { + + private static final Duration TEST_SUBSCRIPTION_TIMEOUT = Duration.ofMillis(1000); + private static final String TEST_CONTINUATION_SEQUENCE_NUMBER = "test-continuation-sequence-number"; + + private AsyncStreamProxy mockAsyncStreamProxy; + private ExecutorService testExecutor; + private FanOutKinesisShardSubscription subscription; + + @BeforeEach + public void setUp() { + mockAsyncStreamProxy = Mockito.mock(AsyncStreamProxy.class); + when(mockAsyncStreamProxy.subscribeToShard(any(), any(), any(), any())) + .thenReturn(CompletableFuture.completedFuture(null)); + + testExecutor = Executors.newFixedThreadPool(4); + } + + /** + * Tests that events are processed sequentially, ensuring that the starting position + * is updated in the correct order. + */ + @Test + @Timeout(value = 30) + public void testEventProcessingSequential() throws Exception { + // Create a custom TestableSubscription that doesn't require shardSubscriber to be initialized + TestableSubscription testSubscription = new TestableSubscription( + mockAsyncStreamProxy, + CONSUMER_ARN, + SHARD_ID, + StartingPosition.fromStart(), + TEST_SUBSCRIPTION_TIMEOUT, + testExecutor, + null); + + // Create test events with different sequence numbers + List testEvents = new ArrayList<>(); + for (int i = 1; i <= 5; i++) { + testEvents.add(createTestEvent("sequence-" + i)); + } + + // Process events sequentially + for (SubscribeToShardEvent event : testEvents) { + testSubscription.processSubscriptionEvent(event); + } + + // Verify that the final starting position is based on the last event + assertThat(testSubscription.getStartingPosition().getStartingMarker()) + .isEqualTo(testEvents.get(testEvents.size() - 1).continuationSequenceNumber()); + } + + /** + * Tests that the subscription event processing lock prevents concurrent processing of events. 
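+     *
+     * <p>The guarantee reduces to plain monitor semantics on the subscription instance (sketch):
+     * <pre>{@code
+     * public synchronized void processSubscriptionEvent(SubscribeToShardEvent event) {
+     *     // at most one thread can be in here per subscription, so a second event
+     *     // submitted while the first is in flight must wait for the monitor
+     * }
+     * }</pre>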
+ */ + @Test + @Timeout(value = 30) + public void testEventProcessingLock() throws Exception { + // Create a CountDownLatch to track when the first task starts + CountDownLatch firstTaskStarted = new CountDownLatch(1); + + // Create a CountDownLatch to control when the first task completes + CountDownLatch allowFirstTaskToComplete = new CountDownLatch(1); + + // Create a CountDownLatch to track when the second task completes + CountDownLatch secondTaskCompleted = new CountDownLatch(1); + + // Create an AtomicInteger to track the order of execution + AtomicInteger executionOrder = new AtomicInteger(0); + + // Create a custom executor that will help us control the execution order + ExecutorService customExecutor = Executors.newFixedThreadPool(2); + + // Create a custom TestableSubscription with a synchronized processSubscriptionEvent method + TestableSubscription testSubscription = new TestableSubscription( + mockAsyncStreamProxy, + CONSUMER_ARN, + SHARD_ID, + StartingPosition.fromStart(), + TEST_SUBSCRIPTION_TIMEOUT, + customExecutor, + null) { + + @Override + public synchronized void processSubscriptionEvent(SubscribeToShardEvent event) { + String sequenceNumber = event.continuationSequenceNumber(); + + if ("sequence-1".equals(sequenceNumber)) { + // First task signals it has started and waits for permission to complete + executionOrder.incrementAndGet(); // Should be 1 + firstTaskStarted.countDown(); + try { + allowFirstTaskToComplete.await(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } else if ("sequence-2".equals(sequenceNumber)) { + // Second task just increments the counter and signals completion + executionOrder.incrementAndGet(); // Should be 2 + secondTaskCompleted.countDown(); + } + + // Call the parent method + super.processSubscriptionEvent(event); + } + }; + + // Submit the first event + CompletableFuture future1 = CompletableFuture.runAsync(() -> { + testSubscription.processSubscriptionEvent(createTestEvent("sequence-1")); + }); + + // Wait for the first task to start + assertThat(firstTaskStarted.await(5, TimeUnit.SECONDS)).isTrue(); + + // Submit the second event + CompletableFuture future2 = CompletableFuture.runAsync(() -> { + testSubscription.processSubscriptionEvent(createTestEvent("sequence-2")); + }); + + // Allow some time for the second task to potentially start if there was no lock + Thread.sleep(500); + + // The second task should not have executed yet due to the lock + assertThat(executionOrder.get()).isEqualTo(1); + + // Allow the first task to complete + allowFirstTaskToComplete.countDown(); + + // Wait for the second task to complete + assertThat(secondTaskCompleted.await(5, TimeUnit.SECONDS)).isTrue(); + + // Verify the execution order + assertThat(executionOrder.get()).isEqualTo(2); + + // Verify both futures completed + CompletableFuture.allOf(future1, future2).get(5, TimeUnit.SECONDS); + } + + /** + * Tests that events are processed using the executor service. 
+ */ + @Test + @Timeout(value = 30) + public void testExecutorServiceUsage() throws Exception { + // Create a latch to track when the executor service is used + CountDownLatch executorUsed = new CountDownLatch(1); + + // Create a custom executor that will signal when it's used + ExecutorService customExecutor = spy(testExecutor); + doAnswer(invocation -> { + executorUsed.countDown(); + return invocation.callRealMethod(); + }).when(customExecutor).execute(any(Runnable.class)); + + // Create a custom TestableSubscription that doesn't require shardSubscriber to be initialized + TestableSubscription testSubscription = new TestableSubscription( + mockAsyncStreamProxy, + CONSUMER_ARN, + SHARD_ID, + StartingPosition.fromStart(), + TEST_SUBSCRIPTION_TIMEOUT, + customExecutor, + null); + + // Submit an event for processing + testSubscription.submitEventProcessingTask(createTestEvent(TEST_CONTINUATION_SEQUENCE_NUMBER)); + + // Verify that the executor was used + assertThat(executorUsed.await(5, TimeUnit.SECONDS)).isTrue(); + verify(customExecutor, times(1)).execute(any(Runnable.class)); + } + + /** + * Tests that exceptions in event processing are properly propagated. + */ + @Test + @Timeout(value = 30) + public void testExceptionPropagation() throws Exception { + // Create a custom TestableSubscription that throws a KinesisStreamsSourceException + TestableSubscription testSubscription = new TestableSubscription( + mockAsyncStreamProxy, + CONSUMER_ARN, + SHARD_ID, + StartingPosition.fromStart(), + TEST_SUBSCRIPTION_TIMEOUT, + testExecutor, + null) { + + @Override + public void processSubscriptionEvent(SubscribeToShardEvent event) { + throw new KinesisStreamsSourceException("Test exception", new RuntimeException("Cause")); + } + }; + + // This should throw a KinesisStreamsSourceException + assertThatThrownBy(() -> { + testSubscription.processSubscriptionEvent(createTestEvent(TEST_CONTINUATION_SEQUENCE_NUMBER)); + }).isInstanceOf(KinesisStreamsSourceException.class); + } + + /** + * Tests that the starting position is updated only after the event is successfully added to the queue. 
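+     *
+     * <p>The test gates the spied queue's {@code put()} with latches; the Mockito idiom used,
+     * in miniature (latch names are the test's own):
+     * <pre>{@code
+     * doAnswer(invocation -> {
+     *     putCalled.countDown();    // signal: put() has been entered
+     *     allowPutToReturn.await(); // hold put() open while the checker thread runs
+     *     return invocation.callRealMethod();
+     * }).when(controlledQueue).put(any(SubscribeToShardEvent.class));
+     * }</pre>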
+ */ + @Test + @Timeout(value = 30) + public void testStartingPositionUpdatedAfterQueuePut() throws Exception { + // Create a blocking queue that we can control + BlockingQueue controlledQueue = spy(new LinkedBlockingQueue<>(2)); + + // Create a latch to track when put is called + CountDownLatch putCalled = new CountDownLatch(1); + + // Create a latch to control when put returns + CountDownLatch allowPutToReturn = new CountDownLatch(1); + + // Create an atomic boolean to track if the starting position was updated before put completed + AtomicBoolean startingPositionUpdatedBeforePutCompleted = new AtomicBoolean(false); + + // Mock the queue's put method to control its execution + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock invocation) throws Throwable { + putCalled.countDown(); + allowPutToReturn.await(5, TimeUnit.SECONDS); + + // Call the real method + invocation.callRealMethod(); + return null; + } + }).when(controlledQueue).put(any(SubscribeToShardEvent.class)); + + // Create a subscription with access to the controlled queue + FanOutKinesisShardSubscription testSubscription = new TestableSubscription( + mockAsyncStreamProxy, + CONSUMER_ARN, + SHARD_ID, + StartingPosition.fromStart(), + TEST_SUBSCRIPTION_TIMEOUT, + testExecutor, + controlledQueue); + + // Create a thread to check the starting position while put is blocked + Thread checkThread = new Thread(() -> { + try { + // Wait for put to be called + assertThat(putCalled.await(5, TimeUnit.SECONDS)).isTrue(); + + // Check if the starting position was updated before put completed + StartingPosition currentPosition = testSubscription.getStartingPosition(); + if (currentPosition.getStartingMarker() != null && + currentPosition.getStartingMarker().equals(TEST_CONTINUATION_SEQUENCE_NUMBER)) { + startingPositionUpdatedBeforePutCompleted.set(true); + } + + // Allow put to return + allowPutToReturn.countDown(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + }); + + // Start the check thread + checkThread.start(); + + // Process an event + testSubscription.processSubscriptionEvent(createTestEvent(TEST_CONTINUATION_SEQUENCE_NUMBER)); + + // Wait for the check thread to complete + checkThread.join(5000); + + // Verify that the starting position was not updated before put completed + assertThat(startingPositionUpdatedBeforePutCompleted.get()).isFalse(); + + // Verify that the starting position was updated after put completed + assertThat(testSubscription.getStartingPosition().getStartingMarker()) + .isEqualTo(TEST_CONTINUATION_SEQUENCE_NUMBER); + } + + /** + * Creates a test SubscribeToShardEvent with the given continuation sequence number. + */ + private SubscribeToShardEvent createTestEvent(String continuationSequenceNumber) { + return SubscribeToShardEvent.builder() + .continuationSequenceNumber(continuationSequenceNumber) + .millisBehindLatest(0L) + .records(new ArrayList<>()) + .build(); + } + + /** + * A testable version of FanOutKinesisShardSubscription that allows access to the event queue + * and overrides methods that require shardSubscriber to be initialized. 
+ */ + private static class TestableSubscription extends FanOutKinesisShardSubscription { + private final BlockingQueue testEventQueue; + private StartingPosition currentStartingPosition; + + public TestableSubscription( + AsyncStreamProxy kinesis, + String consumerArn, + String shardId, + StartingPosition startingPosition, + Duration subscriptionTimeout, + ExecutorService subscriptionEventProcessingExecutor, + BlockingQueue testEventQueue) { + super(kinesis, consumerArn, shardId, startingPosition, subscriptionTimeout, subscriptionEventProcessingExecutor); + this.testEventQueue = testEventQueue; + this.currentStartingPosition = startingPosition; + } + + @Override + public StartingPosition getStartingPosition() { + return currentStartingPosition; + } + + @Override + public void processSubscriptionEvent(SubscribeToShardEvent event) { + try { + if (testEventQueue != null) { + testEventQueue.put(event); + } + + // Update the starting position to ensure we can recover after failover + currentStartingPosition = StartingPosition.continueFromSequenceNumber( + event.continuationSequenceNumber()); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new KinesisStreamsSourceException( + "Interrupted while adding Kinesis record to internal buffer.", e); + } + } + + /** + * Public method to submit an event processing task directly to the executor. + * This is used for testing the executor service usage. + */ + public void submitEventProcessingTask(SubscribeToShardEvent event) { + try { + // Use reflection to access the private executor field + java.lang.reflect.Field field = FanOutKinesisShardSubscription.class.getDeclaredField("subscriptionEventProcessingExecutor"); + field.setAccessible(true); + ExecutorService executor = (ExecutorService) field.get(this); + + executor.execute(() -> { + synchronized (this) { + processSubscriptionEvent(event); + } + }); + } catch (Exception e) { + throw new KinesisStreamsSourceException( + "Error submitting subscription event task", e); + } + } + } +} diff --git a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardTestBase.java b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardTestBase.java new file mode 100644 index 00000000..bdd8bdd7 --- /dev/null +++ b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardTestBase.java @@ -0,0 +1,198 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.connector.kinesis.source.reader.fanout; + +import org.apache.flink.configuration.Configuration; +import org.apache.flink.connector.kinesis.source.config.KinesisSourceConfigOptions; +import org.apache.flink.connector.kinesis.source.proxy.AsyncStreamProxy; +import org.apache.flink.connector.kinesis.source.split.StartingPosition; +import org.apache.flink.core.testutils.ManuallyTriggeredScheduledExecutorService; +import org.apache.flink.metrics.MetricGroup; + +import org.junit.jupiter.api.BeforeEach; +import org.mockito.Mockito; +import software.amazon.awssdk.services.kinesis.model.Record; +import software.amazon.awssdk.services.kinesis.model.SubscribeToShardEvent; + +import java.time.Duration; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutorService; + +import static org.apache.flink.connector.kinesis.source.util.TestUtil.STREAM_ARN; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +/** + * Base class for Kinesis shard tests. + */ +public abstract class FanOutKinesisShardTestBase { + + protected static final Duration TEST_SUBSCRIPTION_TIMEOUT = Duration.ofMillis(1000); + protected static final String SHARD_ID_1 = "shardId-000000000001"; + protected static final String SHARD_ID_2 = "shardId-000000000002"; + protected static final String CONSUMER_ARN = "abcdedf"; + + protected AsyncStreamProxy mockAsyncStreamProxy; + protected ManuallyTriggeredScheduledExecutorService testExecutor; + protected MetricGroup mockMetricGroup; + + @BeforeEach + public void setUp() { + mockAsyncStreamProxy = Mockito.mock(AsyncStreamProxy.class); + when(mockAsyncStreamProxy.subscribeToShard(any(), any(), any(), any())) + .thenReturn(CompletableFuture.completedFuture(null)); + + testExecutor = new ManuallyTriggeredScheduledExecutorService(); + + mockMetricGroup = mock(MetricGroup.class); + when(mockMetricGroup.addGroup(any(String.class))).thenReturn(mockMetricGroup); + when(mockMetricGroup.addGroup(any(String.class), any(String.class))).thenReturn(mockMetricGroup); + } + + /** + * A testable version of FanOutKinesisShardSubscription that captures processed records. 
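+     *
+     * <p>Typical use in a test (all names below come from this test suite):
+     * <pre>{@code
+     * BlockingQueue<Record> captured = new LinkedBlockingQueue<>();
+     * TestableSubscription sub =
+     *         createTestableSubscription(SHARD_ID_1, StartingPosition.fromStart(), captured);
+     * sub.processSubscriptionEvent(FanOutKinesisTestUtils.createTestEvent(
+     *         "seq-1", Collections.singletonList(FanOutKinesisTestUtils.createTestRecord("r1"))));
+     * // captured now holds the event's record; getStartingPosition() reflects "seq-1"
+     * }</pre>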
+ */ + protected static class TestableSubscription extends FanOutKinesisShardSubscription { + private final BlockingQueue recordQueue; + private volatile StartingPosition currentStartingPosition; + private volatile boolean shouldUpdateStartingPosition = true; + + public TestableSubscription( + AsyncStreamProxy kinesis, + String consumerArn, + String shardId, + StartingPosition startingPosition, + Duration subscriptionTimeout, + ExecutorService subscriptionEventProcessingExecutor, + BlockingQueue recordQueue) { + super(kinesis, consumerArn, shardId, startingPosition, subscriptionTimeout, subscriptionEventProcessingExecutor); + this.recordQueue = recordQueue; + this.currentStartingPosition = startingPosition; + } + + @Override + public StartingPosition getStartingPosition() { + return currentStartingPosition; + } + + @Override + public void processSubscriptionEvent(SubscribeToShardEvent event) { + boolean recordsProcessed = false; + + try { + // Add all records to the queue + if (recordQueue != null && event.records() != null) { + for (Record record : event.records()) { + recordQueue.put(record); + } + recordsProcessed = true; + } + + // Update the starting position only if records were processed + if (recordsProcessed && shouldUpdateStartingPosition) { + String continuationSequenceNumber = event.continuationSequenceNumber(); + if (continuationSequenceNumber != null) { + currentStartingPosition = StartingPosition.continueFromSequenceNumber(continuationSequenceNumber); + } + } + + // Note: We're not calling super.processSubscriptionEvent(event) here + // because that would try to use the shardSubscriber which is null in our test + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Interrupted while processing event", e); + } + } + + public void setShouldUpdateStartingPosition(boolean shouldUpdateStartingPosition) { + this.shouldUpdateStartingPosition = shouldUpdateStartingPosition; + } + } + + /** + * Creates a TestableSubscription for testing. + * + * @param shardId The shard ID + * @param startingPosition The starting position + * @param recordQueue The queue to store processed records + * @return A TestableSubscription + */ + protected TestableSubscription createTestableSubscription( + String shardId, + StartingPosition startingPosition, + BlockingQueue recordQueue) { + return new TestableSubscription( + mockAsyncStreamProxy, + CONSUMER_ARN, + shardId, + startingPosition, + TEST_SUBSCRIPTION_TIMEOUT, + testExecutor, + recordQueue); + } + + /** + * Creates a test subscription factory. + * + * @return A test subscription factory + */ + protected FanOutKinesisShardSplitReader.SubscriptionFactory createTestSubscriptionFactory() { + return (proxy, consumerArn, shardId, startingPosition, timeout, executor) -> + new FanOutKinesisShardSubscription( + proxy, + consumerArn, + shardId, + startingPosition, + timeout, + executor); + } + + /** + * Creates a FanOutKinesisShardSplitReader with a single shard. 
+ * + * @param shardId The shard ID + * @return A FanOutKinesisShardSplitReader + */ + protected FanOutKinesisShardSplitReader createSplitReaderWithShard(String shardId) { + // Create a Configuration object and set the timeout + Configuration configuration = new Configuration(); + configuration.set(KinesisSourceConfigOptions.EFO_CONSUMER_SUBSCRIPTION_TIMEOUT, TEST_SUBSCRIPTION_TIMEOUT); + + FanOutKinesisShardSplitReader reader = new FanOutKinesisShardSplitReader( + mockAsyncStreamProxy, + CONSUMER_ARN, + Mockito.mock(java.util.Map.class), + configuration, + createTestSubscriptionFactory(), + testExecutor); + + // Create a split + reader.handleSplitsChanges( + new org.apache.flink.connector.base.source.reader.splitreader.SplitsAddition<>( + java.util.Collections.singletonList( + FanOutKinesisTestUtils.createTestSplit( + STREAM_ARN, + shardId, + StartingPosition.fromStart())))); + + return reader; + } +} diff --git a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisTestUtils.java b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisTestUtils.java new file mode 100644 index 00000000..2e3c814b --- /dev/null +++ b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisTestUtils.java @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.connector.kinesis.source.reader.fanout; + +import org.apache.flink.connector.kinesis.source.split.KinesisShardSplit; +import org.apache.flink.connector.kinesis.source.split.StartingPosition; +import org.apache.flink.connector.kinesis.source.util.TestUtil; + +import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.services.kinesis.model.Record; +import software.amazon.awssdk.services.kinesis.model.SubscribeToShardEvent; + +import java.nio.charset.StandardCharsets; +import java.time.Instant; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +/** + * Utility class for Kinesis tests. + */ +public class FanOutKinesisTestUtils { + + /** + * Creates a test Record with the given data. + * + * @param data The data to include in the record + * @return A test Record + */ + public static Record createTestRecord(String data) { + return Record.builder() + .data(SdkBytes.fromString(data, StandardCharsets.UTF_8)) + .approximateArrivalTimestamp(Instant.now()) + .partitionKey("partitionKey") + .sequenceNumber("sequenceNumber") + .build(); + } + + /** + * Creates a test SubscribeToShardEvent with the given continuation sequence number and records. 
+ * + * @param continuationSequenceNumber The continuation sequence number + * @param records The records to include in the event + * @return A test SubscribeToShardEvent + */ + public static SubscribeToShardEvent createTestEvent(String continuationSequenceNumber, List records) { + return SubscribeToShardEvent.builder() + .continuationSequenceNumber(continuationSequenceNumber) + .millisBehindLatest(0L) + .records(records) + .build(); + } + + /** + * Creates a test SubscribeToShardEvent with the given continuation sequence number, records, and millisBehindLatest. + * + * @param continuationSequenceNumber The continuation sequence number + * @param records The records to include in the event + * @param millisBehindLatest The milliseconds behind latest + * @return A test SubscribeToShardEvent + */ + public static SubscribeToShardEvent createTestEvent( + String continuationSequenceNumber, List records, long millisBehindLatest) { + return SubscribeToShardEvent.builder() + .continuationSequenceNumber(continuationSequenceNumber) + .millisBehindLatest(millisBehindLatest) + .records(records) + .build(); + } + + /** + * Creates a test KinesisShardSplit. + * + * @param streamArn The stream ARN + * @param shardId The shard ID + * @param startingPosition The starting position + * @return A test KinesisShardSplit + */ + public static KinesisShardSplit createTestSplit( + String streamArn, String shardId, StartingPosition startingPosition) { + return new KinesisShardSplit( + streamArn, + shardId, + startingPosition, + Collections.emptySet(), + TestUtil.STARTING_HASH_KEY_TEST_VALUE, + TestUtil.ENDING_HASH_KEY_TEST_VALUE); + } + + /** + * Gets the subscription for a specific shard from the reader using reflection. + * + * @param reader The reader + * @param shardId The shard ID + * @return The subscription + * @throws Exception If an error occurs + */ + public static FanOutKinesisShardSubscription getSubscriptionFromReader( + FanOutKinesisShardSplitReader reader, String shardId) throws Exception { + // Get access to the subscriptions map + java.lang.reflect.Field field = FanOutKinesisShardSplitReader.class.getDeclaredField("splitSubscriptions"); + field.setAccessible(true); + Map subscriptions = + (Map) field.get(reader); + return subscriptions.get(shardId); + } + + /** + * Sets the starting position in a subscription using reflection. + * + * @param subscription The subscription + * @param startingPosition The starting position + * @throws Exception If an error occurs + */ + public static void setStartingPositionInSubscription( + FanOutKinesisShardSubscription subscription, StartingPosition startingPosition) throws Exception { + // Get access to the startingPosition field + java.lang.reflect.Field field = subscription.getClass().getDeclaredField("startingPosition"); + field.setAccessible(true); + field.set(subscription, startingPosition); + } +} From b4b98d68fe1cd8339a96db8164de9860e244151e Mon Sep 17 00:00:00 2001 From: Nagesh Honnalli Date: Tue, 20 May 2025 10:23:44 -0700 Subject: [PATCH 2/4] [FLINK-34071][Connectors/Kinesis] Handing off kinesis client close to a separate executor so it doesn't block the NettyEventLoop threads, adding null checks, improving logging and documentation. 
--- .../fanout/FanOutKinesisShardSplitReader.java | 63 +++++++++-- .../FanOutKinesisShardSubscription.java | 103 +++++++++++------- 2 files changed, 119 insertions(+), 47 deletions(-) diff --git a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSplitReader.java b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSplitReader.java index 10370a53..5d3c46b6 100644 --- a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSplitReader.java +++ b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSplitReader.java @@ -30,15 +30,21 @@ import org.apache.flink.connector.kinesis.source.split.StartingPosition; import org.apache.flink.util.concurrent.ExecutorThreadFactory; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import software.amazon.awssdk.services.kinesis.model.SubscribeToShardEvent; +import java.io.IOException; import java.time.Duration; import java.util.HashMap; import java.util.Map; +import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import static org.apache.flink.connector.kinesis.source.config.KinesisSourceConfigOptions.EFO_CONSUMER_SUBSCRIPTION_TIMEOUT; @@ -48,6 +54,8 @@ */ @Internal public class FanOutKinesisShardSplitReader extends KinesisShardSplitReaderBase { + + private static final Logger LOG = LoggerFactory.getLogger(FanOutKinesisShardSplitReader.class); private final AsyncStreamProxy asyncStreamProxy; private final String consumerArn; private final Duration subscriptionTimeout; @@ -55,18 +63,29 @@ public class FanOutKinesisShardSplitReader extends KinesisShardSplitReaderBase { /** * Shared executor service for all shard subscriptions. * - *

<p>This executor uses an unbounded queue ({@link LinkedBlockingQueue}) to ensure no tasks are ever rejected.
-     * Although the queue is technically unbounded, the system has natural flow control mechanisms that effectively
-     * bound the queue size:
+     * <p>This executor uses an unbounded queue ({@link LinkedBlockingQueue}) but this does not pose
+     * a risk of out-of-memory errors due to the natural flow control mechanisms in the system:
      *
      * <ol>
      *   <li>Each {@link FanOutKinesisShardSubscription} has a bounded event queue with capacity of 2</li>
      *   <li>New records are only requested after processing an event (via {@code requestRecords()})</li>
-     *   <li>The maximum number of queued tasks is effectively bounded by {@code 2 * number_of_shards}</li>
+     *   <li>When a shard's queue is full, the processing thread blocks at the {@code put()} operation</li>
+     *   <li>The AWS SDK implements the Reactive Streams protocol with built-in backpressure</li>
      * </ol>
      *
-     * <p>This design provides natural backpressure while ensuring no records are dropped, making it safe
-     * to use an unbounded executor queue.
+     * <p>In the worst-case scenario during backpressure, the maximum number of events in memory is:
+     *
+     * <pre>
+     * Max Events = (2 * Number_of_Shards) + min(Number_of_Shards, Number_of_Threads)
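+     *
+     * For example, 10 shards with a 4-thread pool give:
+     *   Max Events = (2 * 10) + min(10, 4) = 24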
+     * </pre>
+     *
+     * <p>Where:
+     *
+     * <ul>
+     *   <li>{@code 2 * Number_of_Shards}: events buffered in the per-shard queues (capacity 2 each)</li>
+     *   <li>{@code min(Number_of_Shards, Number_of_Threads)}: events currently being processed by executor threads</li>
+     * </ul>
+     *
+     * <p>
This ensures that memory usage scales linearly with the number of shards, not exponentially, + * making it safe to use an unbounded executor queue even with a large number of shards. */ private final ExecutorService sharedShardSubscriptionExecutor; @@ -212,7 +231,7 @@ public void handleSplitsChanges(SplitsChange splitsChanges) { @Override public void close() throws Exception { - // Shutdown the executor service + // Shutdown the executor service first if (sharedShardSubscriptionExecutor != null) { sharedShardSubscriptionExecutor.shutdown(); try { @@ -225,6 +244,34 @@ public void close() throws Exception { } } - asyncStreamProxy.close(); + // Create a separate single-threaded executor for closing the asyncStreamProxy + ExecutorService closeExecutor = new ThreadPoolExecutor( + 1, 1, + 0L, TimeUnit.MILLISECONDS, + new LinkedBlockingQueue<>(), + new ExecutorThreadFactory("kinesis-client-close")); + + try { + // Submit the close task to the executor and wait with a timeout + Future closeFuture = closeExecutor.submit(() -> { + try { + asyncStreamProxy.close(); + } catch (IOException e) { + LOG.warn("Error closing async stream proxy", e); + } + }); + + // Wait for the close operation to complete with a timeout + try { + closeFuture.get(30, TimeUnit.SECONDS); + } catch (TimeoutException e) { + LOG.warn("Timed out while closing async stream proxy", e); + } catch (ExecutionException e) { + LOG.warn("Error while closing async stream proxy", e.getCause()); + } + } finally { + // Ensure the close executor is shut down + closeExecutor.shutdownNow(); + } } } diff --git a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSubscription.java b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSubscription.java index 7d5ba7a1..f041a75c 100644 --- a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSubscription.java +++ b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSubscription.java @@ -19,6 +19,7 @@ package org.apache.flink.connector.kinesis.source.reader.fanout; import org.apache.flink.annotation.Internal; +import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.connector.kinesis.source.exception.KinesisStreamsSourceException; import org.apache.flink.connector.kinesis.source.proxy.AsyncStreamProxy; import org.apache.flink.connector.kinesis.source.split.StartingPosition; @@ -81,17 +82,7 @@ public class FanOutKinesisShardSubscription { /** Executor service to run subscription event processing tasks. */ private final ExecutorService subscriptionEventProcessingExecutor; - /** - * Lock to ensure sequential processing of subscription events for this shard. - * This lock guarantees that for each shard: - * 1. Only one event is processed at a time - * 2. Events are processed in the order they are received - * 3. The critical operations (queue.put, startingPosition update, requestRecords) are executed atomically - * - *

This is essential to prevent race conditions that could lead to data loss or incorrect - * continuation sequence numbers being used after failover. - */ - private final Object subscriptionEventProcessingLock = new Object(); + // Lock removed as we're using method-level synchronization instead // Queue is meant for eager retrieval of records from the Kinesis stream. We will always have 2 // record batches available on next read. @@ -163,10 +154,21 @@ public void activateSubscription() { // We have to use our own CountDownLatch to wait for subscription to be acquired because // subscription event is tracked via the handler. CountDownLatch waitForSubscriptionLatch = new CountDownLatch(1); - shardSubscriber = new FanOutShardSubscriber(waitForSubscriptionLatch); + + // Create a local variable for the new subscriber to prevent a potential race condition + // where the shardSubscriber field might be modified by another thread between when we + // create the lambda and when it's executed. By using a local variable that is captured + // by the lambda, we ensure that the lambda always uses the subscriber instance that was + // created in this method call, regardless of any concurrent modifications to the + // shardSubscriber field. + FanOutShardSubscriber newSubscriber = new FanOutShardSubscriber(waitForSubscriptionLatch); + shardSubscriber = newSubscriber; + SubscribeToShardResponseHandler responseHandler = SubscribeToShardResponseHandler.builder() - .subscriber(() -> shardSubscriber) + // Use the local variable in the lambda to ensure we're always using the + // subscriber instance created in this method call + .subscriber(() -> newSubscriber) .onError( throwable -> { // Errors that occur when obtaining a subscription are thrown @@ -239,10 +241,16 @@ public void activateSubscription() { private void terminateSubscription(Throwable t) { if (!subscriptionException.compareAndSet(null, t)) { LOG.warn( - "Another subscription exception has been queued, ignoring subsequent exceptions", + "Another subscription exception has been queued for shard {}, ignoring subsequent exceptions", + shardId, t); } - shardSubscriber.cancel(); + + if (shardSubscriber != null) { + shardSubscriber.cancel(); + } else { + LOG.warn("Cannot terminate subscription - shardSubscriber is null for shard {}", shardId); + } } /** @@ -273,15 +281,22 @@ public SubscribeToShardEvent nextEvent() { .findFirst(); if (recoverableException.isPresent()) { LOG.warn( - "Recoverable exception encountered while subscribing to shard. Ignoring.", + "Recoverable exception encountered while subscribing to shard {}. 
Ignoring.", + shardId, recoverableException.get()); - shardSubscriber.cancel(); + + if (shardSubscriber != null) { + shardSubscriber.cancel(); + } else { + LOG.warn("Cannot cancel subscription - shardSubscriber is null for shard {}", shardId); + } + activateSubscription(); return null; } - LOG.error("Subscription encountered unrecoverable exception.", throwable); + LOG.error("Subscription encountered unrecoverable exception for shard {}.", shardId, throwable); throw new KinesisStreamsSourceException( - "Subscription encountered unrecoverable exception.", throwable); + String.format("Subscription encountered unrecoverable exception for shard %s.", shardId), throwable); } if (!subscriptionActive.get()) { @@ -309,17 +324,23 @@ private FanOutShardSubscriber(CountDownLatch subscriptionLatch) { } public void requestRecords() { - subscription.request(1); + if (subscription != null) { + subscription.request(1); + } else { + LOG.warn("Cannot request records - subscription is null for shard {}", shardId); + } } public void cancel() { if (!subscriptionActive.get()) { - LOG.warn("Trying to cancel inactive subscription. Ignoring."); + LOG.warn("Trying to cancel inactive subscription for shard {}. Ignoring.", shardId); return; } subscriptionActive.set(false); if (subscription != null) { subscription.cancel(); + } else { + LOG.debug("Subscription already null during cancellation for shard {}", shardId); } } @@ -378,16 +399,15 @@ public void onComplete() { private void submitEventProcessingTask(SubscribeToShardEvent event) { try { subscriptionEventProcessingExecutor.execute(() -> { - synchronized (subscriptionEventProcessingLock) { - try { - processSubscriptionEvent(event); - } catch (Exception e) { - // For critical path operations, propagate exceptions to cause a Flink job restart - LOG.error("Error processing subscription event", e); - // Propagate the exception to the subscription exception handler - terminateSubscription(new KinesisStreamsSourceException( - "Error processing subscription event", e)); - } + try { + // No synchronized block here, rely on the method-level synchronization + processSubscriptionEvent(event); + } catch (Exception e) { + // For critical path operations, propagate exceptions to cause a Flink job restart + LOG.error("Error processing subscription event", e); + // Propagate the exception to the subscription exception handler + terminateSubscription(new KinesisStreamsSourceException( + "Error processing subscription event", e)); } }); } catch (Exception e) { @@ -411,11 +431,10 @@ private void submitEventProcessingTask(SubscribeToShardEvent event) { * with only requesting more records after processing an event provides natural flow control, * effectively limiting the number of tasks in the executor's queue. * - *

This method is made public for testing purposes. - * * @param event The subscription event to process */ - public void processSubscriptionEvent(SubscribeToShardEvent event) { + @VisibleForTesting + synchronized void processSubscriptionEvent(SubscribeToShardEvent event) { try { if (LOG.isDebugEnabled()) { LOG.debug( @@ -429,13 +448,19 @@ public void processSubscriptionEvent(SubscribeToShardEvent event) { eventQueue.put(event); // Update the starting position to ensure we can recover after failover - // Note: We don't need additional synchronization here because this method is already - // called within a synchronized block on subscriptionEventProcessingLock - startingPosition = StartingPosition.continueFromSequenceNumber( - event.continuationSequenceNumber()); + if (event.continuationSequenceNumber() != null) { + startingPosition = StartingPosition.continueFromSequenceNumber( + event.continuationSequenceNumber()); + } else { + LOG.warn("Received null continuation sequence number for shard {}", shardId); + } // Request more records - shardSubscriber.requestRecords(); + if (shardSubscriber != null) { + shardSubscriber.requestRecords(); + } else { + LOG.warn("Cannot request more records - shardSubscriber is null for shard {}", shardId); + } if (LOG.isDebugEnabled()) { LOG.debug( From 9ad13ea76e6c6d75e356dd9451a6b17858f963f5 Mon Sep 17 00:00:00 2001 From: Nagesh Honnalli Date: Thu, 22 May 2025 12:31:39 -0700 Subject: [PATCH 3/4] [FLINK-34071][Connectors/Kinesis] Refactoring splitReader close code and implementing the right ordering, improving error handling, logging and documentation --- .../fanout/FanOutKinesisShardSplitReader.java | 106 ++++++++++++-- .../FanOutKinesisShardSubscription.java | 133 +++++++++++++----- 2 files changed, 191 insertions(+), 48 deletions(-) diff --git a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSplitReader.java b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSplitReader.java index 5d3c46b6..99d9302a 100644 --- a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSplitReader.java +++ b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSplitReader.java @@ -47,6 +47,7 @@ import java.util.concurrent.TimeoutException; import static org.apache.flink.connector.kinesis.source.config.KinesisSourceConfigOptions.EFO_CONSUMER_SUBSCRIPTION_TIMEOUT; +import static org.apache.flink.connector.kinesis.source.config.KinesisSourceConfigOptions.EFO_DEREGISTER_CONSUMER_TIMEOUT; /** * An implementation of the KinesisShardSplitReader that consumes from Kinesis using Enhanced @@ -59,6 +60,7 @@ public class FanOutKinesisShardSplitReader extends KinesisShardSplitReaderBase { private final AsyncStreamProxy asyncStreamProxy; private final String consumerArn; private final Duration subscriptionTimeout; + private final Duration deregisterTimeout; /** * Shared executor service for all shard subscriptions. 
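The explicit subscriptionEventProcessingLock is gone; processSubscriptionEvent is declared synchronized instead. The two forms lock the same monitor, as this sketch shows:

```java
// Synchronizing the method...
synchronized void processSubscriptionEvent(Object event) {
    // critical section
}

// ...is shorthand for synchronizing on the instance itself:
void processSubscriptionEventExplicit(Object event) {
    synchronized (this) {
        // critical section
    }
}
```

Because each shard has its own subscription instance, events for different shards still proceed in parallel; only events for the same shard are serialized.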
@@ -175,6 +177,7 @@ public FanOutKinesisShardSplitReader( this.asyncStreamProxy = asyncStreamProxy; this.consumerArn = consumerArn; this.subscriptionTimeout = configuration.get(EFO_CONSUMER_SUBSCRIPTION_TIMEOUT); + this.deregisterTimeout = configuration.get(EFO_DEREGISTER_CONSUMER_TIMEOUT); this.subscriptionFactory = subscriptionFactory; this.sharedShardSubscriptionExecutor = executorService; } @@ -229,22 +232,92 @@ public void handleSplitsChanges(SplitsChange splitsChanges) { } } + /** + * Closes the split reader and releases all resources. + * + *

+     * <p>The close method follows a specific order to ensure proper shutdown:
+     * <ol>
+     *   <li>First, cancel all active subscriptions to prevent new events from being processed</li>
+     *   <li>Then, shut down the shared executor service to stop processing existing events</li>
+     *   <li>Finally, close the async stream proxy to release network resources</li>
+     * </ol>
+     *

This ordering is critical because: + * - Cancelling subscriptions first prevents new events from being submitted to the executor + * - Shutting down the executor next ensures all in-flight tasks complete or are cancelled + * - Closing the async stream proxy last ensures all resources are properly released after + * all processing has stopped + */ @Override public void close() throws Exception { - // Shutdown the executor service first - if (sharedShardSubscriptionExecutor != null) { - sharedShardSubscriptionExecutor.shutdown(); - try { - if (!sharedShardSubscriptionExecutor.awaitTermination(10, TimeUnit.SECONDS)) { - sharedShardSubscriptionExecutor.shutdownNow(); + cancelActiveSubscriptions(); + shutdownSharedShardSubscriptionExecutor(); + closeAsyncStreamProxy(); + } + + /** + * Cancels all active subscriptions to prevent new events from being processed. + * + *
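A compressed view of why this order matters, with assumed names standing in for the private helpers:

```java
// 1. Stop the event source first, so nothing new reaches the executor.
subscriptions.forEach(FanOutKinesisShardSubscription::cancelSubscription);

// 2. Only then drain the executor: in-flight tasks finish or are cancelled.
eventExecutor.shutdown();

// 3. Close the network client last, after all processing has stopped.
//    Swapping steps 1 and 2 would let a still-active subscription submit
//    tasks to an already shut down executor during an otherwise clean close.
asyncStreamProxy.close();
```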

After cancelling subscriptions, we wait a short time to allow the cancellation + * signals to propagate before proceeding with executor shutdown. + */ + private void cancelActiveSubscriptions() { + for (FanOutKinesisShardSubscription subscription : splitSubscriptions.values()) { + if (subscription.isActive()) { + try { + subscription.cancelSubscription(); + } catch (Exception e) { + LOG.warn("Error cancelling subscription for shard {}", + subscription.getShardId(), e); } - } catch (InterruptedException e) { + } + } + + // Wait a short time (200ms) to allow cancellation signals to propagate + // This helps ensure that no new tasks are submitted to the executor after we begin shutdown + try { + Thread.sleep(200); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } + + /** + * Shuts down the shared executor service used for processing subscription events. + * + *

We use the EFO_DEREGISTER_CONSUMER_TIMEOUT (10 seconds) as the shutdown timeout + * to maintain consistency with other deregistration operations in the connector. + */ + private void shutdownSharedShardSubscriptionExecutor() { + if (sharedShardSubscriptionExecutor == null) { + return; + } + + sharedShardSubscriptionExecutor.shutdown(); + try { + // Use the deregister consumer timeout (10 seconds) + // This timeout is consistent with other deregistration operations in the connector + if (!sharedShardSubscriptionExecutor.awaitTermination( + deregisterTimeout.toMillis(), + TimeUnit.MILLISECONDS)) { + LOG.warn("Executor did not terminate in the specified time. Forcing shutdown."); sharedShardSubscriptionExecutor.shutdownNow(); - Thread.currentThread().interrupt(); } + } catch (InterruptedException e) { + LOG.warn("Interrupted while waiting for executor shutdown", e); + sharedShardSubscriptionExecutor.shutdownNow(); + Thread.currentThread().interrupt(); } + } - // Create a separate single-threaded executor for closing the asyncStreamProxy + /** + * Closes the async stream proxy with a timeout. + * + *
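The sequence above is the standard two-phase shutdown pattern from the ExecutorService javadoc, extracted here into a reusable sketch:

```java
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;

final class ShutdownUtil {
    /** Orderly shutdown first, forced shutdown after the timeout, interrupt status preserved. */
    static void shutdownGracefully(ExecutorService executor, long timeoutMillis) {
        executor.shutdown(); // stop accepting new tasks; queued tasks may still run
        try {
            if (!executor.awaitTermination(timeoutMillis, TimeUnit.MILLISECONDS)) {
                executor.shutdownNow(); // interrupt whatever is still running
            }
        } catch (InterruptedException e) {
            executor.shutdownNow();
            Thread.currentThread().interrupt(); // re-assert for callers up the stack
        }
    }
}
```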

We use the EFO_CONSUMER_SUBSCRIPTION_TIMEOUT (60 seconds) as the close timeout + * since closing the client involves similar network operations as subscription. + * The longer timeout accounts for potential network delays during shutdown. + */ + private void closeAsyncStreamProxy() { + // Create a dedicated single-threaded executor for closing the asyncStreamProxy + // This prevents the close operation from being affected by the main executor shutdown ExecutorService closeExecutor = new ThreadPoolExecutor( 1, 1, 0L, TimeUnit.MILLISECONDS, @@ -252,7 +325,7 @@ public void close() throws Exception { new ExecutorThreadFactory("kinesis-client-close")); try { - // Submit the close task to the executor and wait with a timeout + // Submit the close task to the executor to avoid blocking the main thread Future closeFuture = closeExecutor.submit(() -> { try { asyncStreamProxy.close(); @@ -261,16 +334,23 @@ public void close() throws Exception { } }); - // Wait for the close operation to complete with a timeout try { - closeFuture.get(30, TimeUnit.SECONDS); + // Use the subscription timeout (60 seconds) + // This longer timeout is necessary because closing the AWS SDK client + // may involve waiting for in-flight network operations to complete + closeFuture.get( + subscriptionTimeout.toMillis(), + TimeUnit.MILLISECONDS); } catch (TimeoutException e) { LOG.warn("Timed out while closing async stream proxy", e); + } catch (InterruptedException e) { + LOG.warn("Interrupted while closing async stream proxy", e); + Thread.currentThread().interrupt(); } catch (ExecutionException e) { LOG.warn("Error while closing async stream proxy", e.getCause()); } } finally { - // Ensure the close executor is shut down + // Ensure the close executor is always shut down to prevent resource leaks closeExecutor.shutdownNow(); } } diff --git a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSubscription.java b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSubscription.java index f041a75c..d4c241a3 100644 --- a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSubscription.java +++ b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSubscription.java @@ -112,6 +112,36 @@ public boolean isActive() { return subscriptionActive.get(); } + /** + * Gets the shard ID for this subscription. + * + * @return The shard ID + */ + public String getShardId() { + return shardId; + } + + /** + * Cancels this subscription. + * This is primarily used during shutdown to cancel active subscriptions. 
+ * + * @return true if the subscription was active and was cancelled, false otherwise + */ + public boolean cancelSubscription() { + if (!subscriptionActive.get()) { + LOG.debug("Skipping cancellation of inactive subscription for shard {}.", shardId); + return false; + } + subscriptionActive.set(false); + if (shardSubscriber != null) { + shardSubscriber.cancel(); + return true; + } else { + LOG.warn("Cannot cancel subscription - shardSubscriber is null for shard {}", shardId); + return false; + } + } + private FanOutShardSubscriber shardSubscriber; /** @@ -246,11 +276,7 @@ private void terminateSubscription(Throwable t) { t); } - if (shardSubscriber != null) { - shardSubscriber.cancel(); - } else { - LOG.warn("Cannot terminate subscription - shardSubscriber is null for shard {}", shardId); - } + cancelSubscription(); } /** @@ -285,12 +311,7 @@ public SubscribeToShardEvent nextEvent() { shardId, recoverableException.get()); - if (shardSubscriber != null) { - shardSubscriber.cancel(); - } else { - LOG.warn("Cannot cancel subscription - shardSubscriber is null for shard {}", shardId); - } - + cancelSubscription(); activateSubscription(); return null; } @@ -332,10 +353,8 @@ public void requestRecords() { } public void cancel() { - if (!subscriptionActive.get()) { - LOG.warn("Trying to cancel inactive subscription for shard {}. Ignoring.", shardId); - return; - } + // Set subscription inactive - this is now handled in cancelSubscription() + // but we keep it here as well for safety subscriptionActive.set(false); if (subscription != null) { subscription.cancel(); @@ -390,6 +409,18 @@ public void onComplete() { } } + /** + * Helper method to determine if shutdown is in progress. + * + * @return true if shutdown is in progress, false otherwise + */ + private boolean isShutdownInProgress() { + // Check if the executor service is shutting down or terminated + // This is the most reliable way to detect if shutdown has been initiated + return subscriptionEventProcessingExecutor.isShutdown() || + subscriptionEventProcessingExecutor.isTerminated(); + } + /** * Submits an event processing task to the executor service. * This method encapsulates the task submission logic and error handling. 
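Note that isShutdownInProgress() is inherently a best-effort guard: the executor's state can change between the check and the subsequent execute() call. A sketch of the race and the usual mitigation, where executor and task are assumed stand-ins:

```java
// The check narrows the window but cannot close it, so rejection must stay survivable.
if (!(executor.isShutdown() || executor.isTerminated())) {
    try {
        executor.execute(task);
    } catch (RejectedExecutionException e) {
        // benign during teardown: the executor shut down after the check above
    }
}
```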
@@ -398,16 +429,27 @@ public void onComplete() { */ private void submitEventProcessingTask(SubscribeToShardEvent event) { try { + // Check if shutdown is in progress before submitting new tasks + // This prevents tasks from being submitted to a shutting down executor + if (isShutdownInProgress()) { + LOG.info("Shutdown in progress, not submitting new event processing task for shard {}", shardId); + return; + } + subscriptionEventProcessingExecutor.execute(() -> { try { - // No synchronized block here, rely on the method-level synchronization + // Process the event processSubscriptionEvent(event); } catch (Exception e) { - // For critical path operations, propagate exceptions to cause a Flink job restart - LOG.error("Error processing subscription event", e); - // Propagate the exception to the subscription exception handler - terminateSubscription(new KinesisStreamsSourceException( - "Error processing subscription event", e)); + // Only log as error if we're not in shutdown mode + if (!isShutdownInProgress()) { + LOG.error("Error processing subscription event", e); + // Propagate the exception to the subscription exception handler + terminateSubscription(new KinesisStreamsSourceException( + "Error processing subscription event", e)); + } else { + LOG.info("Error during shutdown while processing event for shard {} - ignoring", shardId, e); + } } }); } catch (Exception e) { @@ -427,21 +469,35 @@ private void submitEventProcessingTask(SubscribeToShardEvent event) { * 3. Requesting more records * *

These operations are executed sequentially for each shard to ensure thread safety - * and prevent race conditions. The bounded nature of the event queue (capacity 2) combined - * with only requesting more records after processing an event provides natural flow control, - * effectively limiting the number of tasks in the executor's queue. + * and prevent race conditions. * * @param event The subscription event to process */ @VisibleForTesting synchronized void processSubscriptionEvent(SubscribeToShardEvent event) { + // Check if the thread is interrupted before doing any work + // This prevents unnecessary processing during shutdown + if (Thread.currentThread().isInterrupted()) { + // During normal operation, an interruption is unexpected and should be treated as an error + // During shutdown, it's expected and can be handled gracefully + if (!isShutdownInProgress()) { + LOG.error("Thread interrupted unexpectedly before processing event for shard {}", shardId); + throw new KinesisStreamsSourceException( + "Thread interrupted unexpectedly before processing event for shard " + shardId, + new InterruptedException()); + } else { + LOG.info("Thread interrupted during shutdown before processing event for shard {} - skipping processing", shardId); + return; + } + } + try { if (LOG.isDebugEnabled()) { LOG.debug( - "Processing event for shard {}: {}, {}", - shardId, - event.getClass().getSimpleName(), - event); + "Processing event for shard {}: {}, {}", + shardId, + event.getClass().getSimpleName(), + event); } // Put event in queue - this is a blocking operation @@ -450,7 +506,7 @@ synchronized void processSubscriptionEvent(SubscribeToShardEvent event) { // Update the starting position to ensure we can recover after failover if (event.continuationSequenceNumber() != null) { startingPosition = StartingPosition.continueFromSequenceNumber( - event.continuationSequenceNumber()); + event.continuationSequenceNumber()); } else { LOG.warn("Received null continuation sequence number for shard {}", shardId); } @@ -464,16 +520,23 @@ synchronized void processSubscriptionEvent(SubscribeToShardEvent event) { if (LOG.isDebugEnabled()) { LOG.debug( - "Successfully processed event for shard {}, updated position to {}", - shardId, - startingPosition); + "Successfully processed event for shard {}, updated position to {}", + shardId, + startingPosition); } } catch (InterruptedException e) { + // Log that we're handling an interruption during shutdown + LOG.info("Interrupted while adding Kinesis record to internal buffer for shard {} - this is expected during shutdown", shardId); + + // Restore the interrupt status Thread.currentThread().interrupt(); - // Consistent with current implementation - throw KinesisStreamsSourceException - throw new KinesisStreamsSourceException( + + // During shutdown, we don't want to throw an exception that would cause a job failure + // Only throw if we're not in a shutdown context + if (!isShutdownInProgress()) { + throw new KinesisStreamsSourceException( "Interrupted while adding Kinesis record to internal buffer.", e); + } } - // No catch for other exceptions - let them propagate to be handled by the AWS SDK } } From 3fd9ce788e2fa1e6a2c874b1392f8f468e1d2c60 Mon Sep 17 00:00:00 2001 From: Nagesh Honnalli Date: Sun, 25 May 2025 23:12:46 -0700 Subject: [PATCH 4/4] [FLINK-34071][Connectors/Kinesis] Handing off shard subscription to a separate shared executor and update to tests --- .../fanout/FanOutKinesisShardSplitReader.java | 114 ++++++++++++++++-- 
.../FanOutKinesisShardSubscription.java | 72 +++++++---- .../FanOutKinesisShardHappyPathTest.java | 8 ++ .../FanOutKinesisShardRecordOrderingTest.java | 4 + .../fanout/FanOutKinesisShardRestartTest.java | 12 ++ .../FanOutKinesisShardSplitReaderTest.java | 11 ++ ...KinesisShardSplitReaderThreadPoolTest.java | 72 ++++++++--- ...esisShardSubscriptionThreadSafetyTest.java | 2 +- .../fanout/FanOutKinesisShardTestBase.java | 10 +- 9 files changed, 249 insertions(+), 56 deletions(-) diff --git a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSplitReader.java b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSplitReader.java index 99d9302a..d1af8b96 100644 --- a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSplitReader.java +++ b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSplitReader.java @@ -91,6 +91,27 @@ public class FanOutKinesisShardSplitReader extends KinesisShardSplitReaderBase { */ private final ExecutorService sharedShardSubscriptionExecutor; + /** + * Shared executor service for making subscribeToShard API calls. + * + *

+     * <p>This executor is separate from the event processing executor to avoid contention
+     * between API calls and event processing. Using a dedicated executor for subscription
+     * calls provides several important benefits:
+     *
+     * <ol>
+     *   <li>Prevents blocking of the main thread or event processing threads during API calls</li>
+     *   <li>Isolates API call failures from event processing operations</li>
+     *   <li>Allows for controlled concurrency of API calls across multiple shards</li>
+     *   <li>Prevents potential deadlocks that could occur when the same thread handles both
+     *       subscription calls and event processing</li>
+     * </ol>
+     *

The executor uses a smaller number of threads than the event processing executor since + * subscription calls are less frequent and primarily I/O bound. This helps optimize resource + * usage while still providing sufficient parallelism for multiple concurrent subscription calls. + */ + private final ExecutorService sharedSubscriptionCallExecutor; + private final Map splitSubscriptions = new HashMap<>(); /** @@ -104,7 +125,8 @@ FanOutKinesisShardSubscription createSubscription( String shardId, StartingPosition startingPosition, Duration timeout, - ExecutorService executor); + ExecutorService eventProcessingExecutor, + ExecutorService subscriptionCallExecutor); } /** @@ -118,14 +140,16 @@ public FanOutKinesisShardSubscription createSubscription( String shardId, StartingPosition startingPosition, Duration timeout, - ExecutorService executor) { + ExecutorService eventProcessingExecutor, + ExecutorService subscriptionCallExecutor) { return new FanOutKinesisShardSubscription( proxy, consumerArn, shardId, startingPosition, timeout, - executor); + eventProcessingExecutor, + subscriptionCallExecutor); } } @@ -152,18 +176,20 @@ public FanOutKinesisShardSplitReader( shardMetricGroupMap, configuration, subscriptionFactory, - createDefaultExecutor()); + createDefaultEventProcessingExecutor(), + createDefaultSubscriptionCallExecutor()); } /** - * Constructor with injected executor service for testing. + * Constructor with injected executor services for testing. * * @param asyncStreamProxy The proxy for Kinesis API calls * @param consumerArn The ARN of the consumer * @param shardMetricGroupMap The metrics map * @param configuration The configuration * @param subscriptionFactory The factory for creating subscriptions - * @param executorService The executor service to use for subscription tasks + * @param eventProcessingExecutor The executor service to use for event processing tasks + * @param subscriptionCallExecutor The executor service to use for subscription API calls */ @VisibleForTesting FanOutKinesisShardSplitReader( @@ -172,22 +198,24 @@ public FanOutKinesisShardSplitReader( Map shardMetricGroupMap, Configuration configuration, SubscriptionFactory subscriptionFactory, - ExecutorService executorService) { + ExecutorService eventProcessingExecutor, + ExecutorService subscriptionCallExecutor) { super(shardMetricGroupMap, configuration); this.asyncStreamProxy = asyncStreamProxy; this.consumerArn = consumerArn; this.subscriptionTimeout = configuration.get(EFO_CONSUMER_SUBSCRIPTION_TIMEOUT); this.deregisterTimeout = configuration.get(EFO_DEREGISTER_CONSUMER_TIMEOUT); this.subscriptionFactory = subscriptionFactory; - this.sharedShardSubscriptionExecutor = executorService; + this.sharedShardSubscriptionExecutor = eventProcessingExecutor; + this.sharedSubscriptionCallExecutor = subscriptionCallExecutor; } /** - * Creates the default executor service for subscription tasks. + * Creates the default executor service for event processing tasks. * * @return A new executor service */ - private static ExecutorService createDefaultExecutor() { + private static ExecutorService createDefaultEventProcessingExecutor() { int minThreads = Runtime.getRuntime().availableProcessors(); int maxThreads = minThreads * 2; return new ThreadPoolExecutor( @@ -198,6 +226,37 @@ private static ExecutorService createDefaultExecutor() { new ExecutorThreadFactory("kinesis-efo-subscription")); } + /** + * Creates the default executor service for subscription API calls. + * + *
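For context on the sizing above, here is a minimal construction of such a pool, together with a standard-library caveat that is easy to miss (this is general ThreadPoolExecutor behavior, not something this patch changes):

```java
import java.util.concurrent.ExecutorService;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

final class PoolSizingSketch {
    // Caveat: a ThreadPoolExecutor only grows beyond corePoolSize once its queue
    // is full. With an unbounded LinkedBlockingQueue the queue is never full, so
    // the pool effectively runs at corePoolSize and the maximum is never reached.
    static ExecutorService newEventProcessingPool() {
        int core = Runtime.getRuntime().availableProcessors();
        return new ThreadPoolExecutor(
                core, core * 2,              // core size and (effectively unused) max size
                60L, TimeUnit.SECONDS,       // keep-alive for threads above core
                new LinkedBlockingQueue<>(), // unbounded: flow control comes from demand
                r -> new Thread(r, "event-processing-sketch"));
    }
}
```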

+     * <p>This executor is configured with:
+     * <ul>
+     *   <li>a minimum of 1 thread and a maximum of max(2, availableProcessors / 4) threads,</li>
+     *   <li>a 60-second keep-alive for idle threads above the core size, and</li>
+     *   <li>an unbounded work queue, relying on the connector's natural flow control.</li>
+     * </ul>
+     *

This configuration balances resource efficiency with responsiveness for subscription calls. + * Since subscription calls are primarily waiting on network I/O, a relatively small number of + * threads can efficiently handle many concurrent calls. + * + * @return A new executor service optimized for subscription API calls + */ + private static ExecutorService createDefaultSubscriptionCallExecutor() { + int minThreads = 1; + int maxThreads = Math.max(2, Runtime.getRuntime().availableProcessors() / 4); + return new ThreadPoolExecutor( + minThreads, + maxThreads, + 60L, TimeUnit.SECONDS, + new LinkedBlockingQueue<>(), // Unbounded queue with natural flow control + new ExecutorThreadFactory("kinesis-subscription-caller")); + } + @Override protected RecordBatch fetchRecords(KinesisShardSplitState splitState) { FanOutKinesisShardSubscription subscription = @@ -226,7 +285,8 @@ public void handleSplitsChanges(SplitsChange splitsChanges) { split.getShardId(), split.getStartingPosition(), subscriptionTimeout, - sharedShardSubscriptionExecutor); + sharedShardSubscriptionExecutor, + sharedSubscriptionCallExecutor); subscription.activateSubscription(); splitSubscriptions.put(split.splitId(), subscription); } @@ -250,6 +310,7 @@ public void handleSplitsChanges(SplitsChange splitsChanges) { public void close() throws Exception { cancelActiveSubscriptions(); shutdownSharedShardSubscriptionExecutor(); + shutdownSharedSubscriptionCallExecutor(); closeAsyncStreamProxy(); } @@ -298,16 +359,43 @@ private void shutdownSharedShardSubscriptionExecutor() { if (!sharedShardSubscriptionExecutor.awaitTermination( deregisterTimeout.toMillis(), TimeUnit.MILLISECONDS)) { - LOG.warn("Executor did not terminate in the specified time. Forcing shutdown."); + LOG.warn("Event processing executor did not terminate in the specified time. Forcing shutdown."); sharedShardSubscriptionExecutor.shutdownNow(); } } catch (InterruptedException e) { - LOG.warn("Interrupted while waiting for executor shutdown", e); + LOG.warn("Interrupted while waiting for event processing executor shutdown", e); sharedShardSubscriptionExecutor.shutdownNow(); Thread.currentThread().interrupt(); } } + /** + * Shuts down the shared executor service used for subscription API calls. + * + *

We use the EFO_DEREGISTER_CONSUMER_TIMEOUT (10 seconds) as the shutdown timeout + * to maintain consistency with other deregistration operations in the connector. + */ + private void shutdownSharedSubscriptionCallExecutor() { + if (sharedSubscriptionCallExecutor == null) { + return; + } + + sharedSubscriptionCallExecutor.shutdown(); + try { + // Use a shorter timeout since these are just API calls + if (!sharedSubscriptionCallExecutor.awaitTermination( + deregisterTimeout.toMillis(), + TimeUnit.MILLISECONDS)) { + LOG.warn("Subscription call executor did not terminate in the specified time. Forcing shutdown."); + sharedSubscriptionCallExecutor.shutdownNow(); + } + } catch (InterruptedException e) { + LOG.warn("Interrupted while waiting for subscription call executor shutdown", e); + sharedSubscriptionCallExecutor.shutdownNow(); + Thread.currentThread().interrupt(); + } + } + /** * Closes the async stream proxy with a timeout. * diff --git a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSubscription.java b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSubscription.java index d4c241a3..6da47ead 100644 --- a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSubscription.java +++ b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSubscription.java @@ -82,6 +82,9 @@ public class FanOutKinesisShardSubscription { /** Executor service to run subscription event processing tasks. */ private final ExecutorService subscriptionEventProcessingExecutor; + /** Executor service to run subscription API calls. */ + private final ExecutorService subscriptionCallExecutor; + // Lock removed as we're using method-level synchronization instead // Queue is meant for eager retrieval of records from the Kinesis stream. We will always have 2 @@ -153,6 +156,7 @@ public boolean cancelSubscription() { * @param startingPosition The starting position for the subscription * @param subscriptionTimeout The timeout for the subscription * @param subscriptionEventProcessingExecutor The executor service to use for processing subscription events + * @param subscriptionCallExecutor The executor service to use for making subscription API calls */ public FanOutKinesisShardSubscription( AsyncStreamProxy kinesis, @@ -160,13 +164,15 @@ public FanOutKinesisShardSubscription( String shardId, StartingPosition startingPosition, Duration subscriptionTimeout, - ExecutorService subscriptionEventProcessingExecutor) { + ExecutorService subscriptionEventProcessingExecutor, + ExecutorService subscriptionCallExecutor) { this.kinesis = kinesis; this.consumerArn = consumerArn; this.shardId = shardId; this.startingPosition = startingPosition; this.subscriptionTimeout = subscriptionTimeout; this.subscriptionEventProcessingExecutor = subscriptionEventProcessingExecutor; + this.subscriptionCallExecutor = subscriptionCallExecutor; } /** Method to allow eager activation of the subscription. 
*/ @@ -211,27 +217,49 @@ public void activateSubscription() { }) .build(); - // We don't need to keep track of the future here because we monitor subscription success - // using our own CountDownLatch - kinesis.subscribeToShard(consumerArn, shardId, startingPosition, responseHandler) - .exceptionally( - throwable -> { - // If consumer exists and is still activating, we want to countdown. - if (ExceptionUtils.findThrowable( - throwable, ResourceInUseException.class) - .isPresent()) { - waitForSubscriptionLatch.countDown(); - return null; - } - LOG.error( - "Error subscribing to shard {} with starting position {} for consumer {}.", - shardId, - startingPosition, - consumerArn, - throwable); - terminateSubscription(throwable); - return null; - }); + // Use the executor service to make the subscription call + // This offloads the potentially blocking API call to a dedicated thread pool, + // preventing it from blocking the main thread or the event processing threads. + // This separation is crucial to avoid potential deadlocks that could occur when + // the Netty event loop thread (used by the AWS SDK) needs to handle both the + // subscription call and the resulting events. + CompletableFuture subscriptionFuture = CompletableFuture.supplyAsync( + () -> { + try { + LOG.debug("Making subscribeToShard API call for shard {} on thread {}", + shardId, Thread.currentThread().getName()); + + // Make the API call using the provided executor + return kinesis.subscribeToShard(consumerArn, shardId, startingPosition, responseHandler); + } catch (Exception e) { + // Handle any exceptions that occur during the API call + LOG.error("Exception during subscribeToShard API call for shard {}", shardId, e); + terminateSubscription(e); + waitForSubscriptionLatch.countDown(); + return CompletableFuture.completedFuture(null); + } + }, + subscriptionCallExecutor + ).thenCompose(future -> future); // Flatten the CompletableFuture> to CompletableFuture + + subscriptionFuture.exceptionally( + throwable -> { + // If consumer exists and is still activating, we want to countdown. + if (ExceptionUtils.findThrowable( + throwable, ResourceInUseException.class) + .isPresent()) { + waitForSubscriptionLatch.countDown(); + return null; + } + LOG.error( + "Error subscribing to shard {} with starting position {} for consumer {}.", + shardId, + startingPosition, + consumerArn, + throwable); + terminateSubscription(throwable); + return null; + }); // We have to handle timeout for subscriptions separately because Java 8 does not support a // fluent orTimeout() methods on CompletableFuture. 
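Two CompletableFuture idioms in the rewritten activateSubscription() are worth seeing in isolation, since both are easy to get wrong on Java 8; class and method names here are illustrative:

```java
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.function.Supplier;

final class FutureIdioms {
    /** Run an async-returning call on a chosen executor and flatten the nesting. */
    static <T> CompletableFuture<T> callOn(Executor executor, Supplier<CompletableFuture<T>> call) {
        // supplyAsync alone yields CompletableFuture<CompletableFuture<T>>;
        // thenCompose(f -> f) flattens it back to CompletableFuture<T>.
        return CompletableFuture.supplyAsync(call, executor).thenCompose(f -> f);
    }

    /** Java 8 has no orTimeout(); a scheduler can complete the future exceptionally instead. */
    static <T> CompletableFuture<T> withTimeout(
            CompletableFuture<T> future, long millis, ScheduledExecutorService timer) {
        timer.schedule(
                () -> future.completeExceptionally(new TimeoutException("timed out")),
                millis, TimeUnit.MILLISECONDS);
        return future; // completeExceptionally is a no-op if the future already completed
    }
}
```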
diff --git a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardHappyPathTest.java b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardHappyPathTest.java index fc18c2c8..0664c820 100644 --- a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardHappyPathTest.java +++ b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardHappyPathTest.java @@ -79,11 +79,15 @@ public void testBasicHappyPathSingleShard() throws Exception { metricsMap, configuration, createTestSubscriptionFactory(), + testExecutor, testExecutor); // Add a split to the reader reader.handleSplitsChanges(new SplitsAddition<>(Collections.singletonList(split))); + // Trigger the executor to execute the subscription tasks + testExecutor.triggerAll(); + // Verify that the subscription was activated ArgumentCaptor shardIdCaptor = ArgumentCaptor.forClass(String.class); ArgumentCaptor startingPositionCaptor = ArgumentCaptor.forClass(StartingPosition.class); @@ -132,6 +136,7 @@ public void testBasicHappyPathMultipleShards() throws Exception { metricsMap, configuration, createTestSubscriptionFactory(), + testExecutor, testExecutor); // Add splits to the reader @@ -140,6 +145,9 @@ public void testBasicHappyPathMultipleShards() throws Exception { splits.add(split2); reader.handleSplitsChanges(new SplitsAddition<>(splits)); + // Trigger the executor to execute the subscription tasks + testExecutor.triggerAll(); + // Verify that subscriptions were activated for both shards ArgumentCaptor shardIdCaptor = ArgumentCaptor.forClass(String.class); ArgumentCaptor startingPositionCaptor = ArgumentCaptor.forClass(StartingPosition.class); diff --git a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardRecordOrderingTest.java b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardRecordOrderingTest.java index 9c3fdf69..d55b4d3a 100644 --- a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardRecordOrderingTest.java +++ b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardRecordOrderingTest.java @@ -258,6 +258,7 @@ public CompletableFuture answer(InvocationOnMock invocation) { Collections.emptyMap(), configuration, createTestSubscriptionFactory(), + testExecutor, testExecutor); // Add a split to the reader @@ -268,6 +269,9 @@ public CompletableFuture answer(InvocationOnMock invocation) { customReader.handleSplitsChanges(new SplitsAddition<>(Collections.singletonList(split))); + // Trigger the executor to execute the subscription tasks + testExecutor.triggerAll(); + // Create test events with records int numEvents = 5; int recordsPerEvent = 10; diff --git a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardRestartTest.java 
b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardRestartTest.java index 2cd7cb7b..8bc65535 100644 --- a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardRestartTest.java +++ b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardRestartTest.java @@ -90,11 +90,15 @@ public void testRestartUsesCorrectStartingPosition() throws Exception { metricsMap, configuration, createTestSubscriptionFactory(), + testExecutor, testExecutor); // Add a split to the reader reader.handleSplitsChanges(new SplitsAddition<>(Collections.singletonList(split))); + // Trigger the executor to execute the subscription tasks + testExecutor.triggerAll(); + // Verify that the subscription was activated with the initial starting position verify(customProxy, times(1)).subscribeToShard( eq(CONSUMER_ARN), @@ -123,11 +127,15 @@ public void testRestartUsesCorrectStartingPosition() throws Exception { metricsMap, restartConfiguration, createTestSubscriptionFactory(), + testExecutor, testExecutor); // Add the updated split to the restarted reader restartedReader.handleSplitsChanges(new SplitsAddition<>(Collections.singletonList(updatedSplit))); + // Trigger the executor to execute the subscription tasks for the restarted reader + testExecutor.triggerAll(); + // Verify that the subscription was reactivated with the updated starting position verify(customProxy, times(2)).subscribeToShard( eq(CONSUMER_ARN), @@ -194,11 +202,15 @@ private void testExceptionHandling(Exception exception, boolean isRecoverable) t metricsMap, configuration, createTestSubscriptionFactory(), + testExecutor, testExecutor); // Add a split to the reader reader.handleSplitsChanges(new SplitsAddition<>(Collections.singletonList(split))); + // Trigger the executor to execute the subscription tasks + testExecutor.triggerAll(); + // If the exception is recoverable, the reader should try to reactivate the subscription // If not, it should propagate the exception if (isRecoverable) { diff --git a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSplitReaderTest.java b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSplitReaderTest.java index 8a00d192..eec78db4 100644 --- a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSplitReaderTest.java +++ b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSplitReaderTest.java @@ -84,6 +84,7 @@ public void testNoAssignedSplitsHandledGracefully() throws Exception { shardMetricGroupMap, newConfigurationForTest(), createTestSubscriptionFactory(), + testExecutor, testExecutor); RecordsWithSplitIds retrievedRecords = splitReader.fetch(); @@ -103,10 +104,14 @@ public void testAssignedSplitHasNoRecordsHandledGracefully() throws Exception { shardMetricGroupMap, newConfigurationForTest(), createTestSubscriptionFactory(), + testExecutor, testExecutor); splitReader.handleSplitsChanges( new SplitsAddition<>(Collections.singletonList(getTestSplit(TEST_SHARD_ID)))); + // Trigger the executor 
to execute the subscription tasks + testExecutor.triggerAll(); + // When fetching records RecordsWithSplitIds retrievedRecords = splitReader.fetch(); @@ -128,10 +133,14 @@ public void testSplitWithExpiredShardHandledAsCompleted() throws Exception { shardMetricGroupMap, newConfigurationForTest(), createTestSubscriptionFactory(), + testExecutor, testExecutor); splitReader.handleSplitsChanges( new SplitsAddition<>(Collections.singletonList(getTestSplit(TEST_SHARD_ID)))); + // Trigger the executor to execute the subscription tasks + testExecutor.triggerAll(); + // When fetching records RecordsWithSplitIds retrievedRecords = splitReader.fetch(); @@ -151,6 +160,7 @@ public void testWakeUpIsNoOp() { shardMetricGroupMap, newConfigurationForTest(), createTestSubscriptionFactory(), + testExecutor, testExecutor); // When wakeup is called @@ -170,6 +180,7 @@ public void testCloseClosesStreamProxy() throws Exception { shardMetricGroupMap, newConfigurationForTest(), createTestSubscriptionFactory(), + testExecutor, testExecutor); // When split reader is not closed diff --git a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSplitReaderThreadPoolTest.java b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSplitReaderThreadPoolTest.java index d6ee85fb..7374cdde 100644 --- a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSplitReaderThreadPoolTest.java +++ b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSplitReaderThreadPoolTest.java @@ -42,6 +42,7 @@ import java.util.List; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.atomic.AtomicInteger; @@ -89,9 +90,9 @@ public void testThreadPoolProcessesMultipleShards() throws Exception { // Create a custom subscription factory that creates test subscriptions FanOutKinesisShardSplitReader.SubscriptionFactory customFactory = - (proxy, consumerArn, shardId, startingPosition, timeout, executor) -> { + (proxy, consumerArn, shardId, startingPosition, timeout, eventProcessingExecutor, subscriptionCallExecutor) -> { TestSubscription subscription = new TestSubscription( - proxy, consumerArn, shardId, startingPosition, timeout, executor, + proxy, consumerArn, shardId, startingPosition, timeout, eventProcessingExecutor, subscriptionCallExecutor, processedEvents, expectedEvents); testSubscriptions.put(shardId, subscription); return subscription; @@ -124,6 +125,7 @@ public void testThreadPoolProcessesMultipleShards() throws Exception { metricsMap, configuration, customFactory, + testExecutor, testExecutor); // Add multiple splits to the reader @@ -200,10 +202,13 @@ public void testThreadPoolFlowControl() throws Exception { customProxy, CONSUMER_ARN, metricsMap, - configuration); + configuration, + createTestSubscriptionFactory(), + Executors.newCachedThreadPool(), + Executors.newCachedThreadPool()); - // Get access to the executor service - ExecutorService executor = getExecutorService(splitReader); + // Get access to the event processing executor service + ExecutorService executor = getEventProcessingExecutorService(splitReader); 
assertThat(executor).isInstanceOf(ThreadPoolExecutor.class); ThreadPoolExecutor threadPoolExecutor = (ThreadPoolExecutor) executor; @@ -224,9 +229,9 @@ public void testThreadPoolFlowControl() throws Exception { // Create a custom subscription factory that adds artificial delay FanOutKinesisShardSplitReader.SubscriptionFactory customFactory = - (proxy, consumerArn, shardId, startingPosition, timeout, executorService) -> { + (proxy, consumerArn, shardId, startingPosition, timeout, eventProcessingExecutor, subscriptionCallExecutor) -> { return new FanOutKinesisShardSubscription( - proxy, consumerArn, shardId, startingPosition, timeout, executorService) { + proxy, consumerArn, shardId, startingPosition, timeout, eventProcessingExecutor, subscriptionCallExecutor) { @Override public void processSubscriptionEvent(SubscribeToShardEvent event) { try { @@ -312,17 +317,23 @@ public void testThreadPoolShutdown() throws Exception { mockAsyncStreamProxy, CONSUMER_ARN, metricsMap, - configuration); + configuration, + createTestSubscriptionFactory(), + Executors.newCachedThreadPool(), + Executors.newCachedThreadPool()); - // Get access to the executor service - ExecutorService executor = getExecutorService(splitReader); - assertThat(executor).isNotNull(); + // Get access to the executor services + ExecutorService eventProcessingExecutor = getEventProcessingExecutorService(splitReader); + ExecutorService subscriptionCallExecutor = getSubscriptionCallExecutorService(splitReader); + assertThat(eventProcessingExecutor).isNotNull(); + assertThat(subscriptionCallExecutor).isNotNull(); // Close the split reader splitReader.close(); - // Verify that the executor service is shut down - assertThat(executor.isShutdown()).isTrue(); + // Verify that both executor services are shut down + assertThat(eventProcessingExecutor.isShutdown()).isTrue(); + assertThat(subscriptionCallExecutor.isShutdown()).isTrue(); } /** @@ -337,14 +348,23 @@ private SubscribeToShardEvent createTestEvent(String continuationSequenceNumber) } /** - * Gets the executor service from the split reader using reflection. + * Gets the event processing executor service from the split reader using reflection. */ - private ExecutorService getExecutorService(FanOutKinesisShardSplitReader splitReader) throws Exception { + private ExecutorService getEventProcessingExecutorService(FanOutKinesisShardSplitReader splitReader) throws Exception { Field field = FanOutKinesisShardSplitReader.class.getDeclaredField("sharedShardSubscriptionExecutor"); field.setAccessible(true); return (ExecutorService) field.get(splitReader); } + /** + * Gets the subscription call executor service from the split reader using reflection. + */ + private ExecutorService getSubscriptionCallExecutorService(FanOutKinesisShardSplitReader splitReader) throws Exception { + Field field = FanOutKinesisShardSplitReader.class.getDeclaredField("sharedSubscriptionCallExecutor"); + field.setAccessible(true); + return (ExecutorService) field.get(splitReader); + } + /** * Sets the subscription factory in the split reader using reflection. */ @@ -356,6 +376,23 @@ private void setSubscriptionFactory( field.set(splitReader, factory); } + /** + * Creates a test subscription factory. 
+ * + * @return A test subscription factory + */ + private FanOutKinesisShardSplitReader.SubscriptionFactory createTestSubscriptionFactory() { + return (proxy, consumerArn, shardId, startingPosition, timeout, eventProcessingExecutor, subscriptionCallExecutor) -> + new FanOutKinesisShardSubscription( + proxy, + consumerArn, + shardId, + startingPosition, + timeout, + eventProcessingExecutor, + subscriptionCallExecutor); + } + /** * A test subscription that ensures we process exactly EVENTS_PER_SHARD events per shard. */ @@ -371,10 +408,11 @@ public TestSubscription( String shardId, StartingPosition startingPosition, Duration timeout, - ExecutorService executor, + ExecutorService eventProcessingExecutor, + ExecutorService subscriptionCallExecutor, AtomicInteger globalCounter, int expectedTotal) { - super(proxy, consumerArn, shardId, startingPosition, timeout, executor); + super(proxy, consumerArn, shardId, startingPosition, timeout, eventProcessingExecutor, subscriptionCallExecutor); this.shardId = shardId; this.globalCounter = globalCounter; this.expectedTotal = expectedTotal; diff --git a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSubscriptionThreadSafetyTest.java b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSubscriptionThreadSafetyTest.java index 1e2a32f6..102c184d 100644 --- a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSubscriptionThreadSafetyTest.java +++ b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardSubscriptionThreadSafetyTest.java @@ -362,7 +362,7 @@ public TestableSubscription( Duration subscriptionTimeout, ExecutorService subscriptionEventProcessingExecutor, BlockingQueue testEventQueue) { - super(kinesis, consumerArn, shardId, startingPosition, subscriptionTimeout, subscriptionEventProcessingExecutor); + super(kinesis, consumerArn, shardId, startingPosition, subscriptionTimeout, subscriptionEventProcessingExecutor, subscriptionEventProcessingExecutor); this.testEventQueue = testEventQueue; this.currentStartingPosition = startingPosition; } diff --git a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardTestBase.java b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardTestBase.java index bdd8bdd7..8e2fc356 100644 --- a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardTestBase.java +++ b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/fanout/FanOutKinesisShardTestBase.java @@ -82,8 +82,9 @@ public TestableSubscription( StartingPosition startingPosition, Duration subscriptionTimeout, ExecutorService subscriptionEventProcessingExecutor, + ExecutorService subscriptionCallExecutor, BlockingQueue recordQueue) { - super(kinesis, consumerArn, shardId, startingPosition, subscriptionTimeout, subscriptionEventProcessingExecutor); + super(kinesis, consumerArn, shardId, startingPosition, subscriptionTimeout, subscriptionEventProcessingExecutor, 
subscriptionCallExecutor); this.recordQueue = recordQueue; this.currentStartingPosition = startingPosition; } @@ -146,6 +147,7 @@ protected TestableSubscription createTestableSubscription( startingPosition, TEST_SUBSCRIPTION_TIMEOUT, testExecutor, + testExecutor, // Use the same executor for subscription calls recordQueue); } @@ -155,14 +157,15 @@ protected TestableSubscription createTestableSubscription( * @return A test subscription factory */ protected FanOutKinesisShardSplitReader.SubscriptionFactory createTestSubscriptionFactory() { - return (proxy, consumerArn, shardId, startingPosition, timeout, executor) -> + return (proxy, consumerArn, shardId, startingPosition, timeout, eventProcessingExecutor, subscriptionCallExecutor) -> new FanOutKinesisShardSubscription( proxy, consumerArn, shardId, startingPosition, timeout, - executor); + eventProcessingExecutor, + subscriptionCallExecutor); } /** @@ -182,6 +185,7 @@ protected FanOutKinesisShardSplitReader createSplitReaderWithShard(String shardI Mockito.mock(java.util.Map.class), configuration, createTestSubscriptionFactory(), + testExecutor, testExecutor); // Create a split