diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ShardConsumer.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ShardConsumer.java index 962611314..d12483848 100644 --- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ShardConsumer.java +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/lifecycle/ShardConsumer.java @@ -179,6 +179,10 @@ public void executeLifecycle() { // Task rejection during the subscribe() call will not be propagated back as it not executed // in the context of the Scheduler thread. Hence we should not assume the subscription will // always be successful. + // But if subscription was not successful, then it will recover + // during healthCheck which will restart subscription. + // From Shardconsumer point of view, initialization after the below subscribe call + // is complete subscribe(); needsInitialization = false; } @@ -276,6 +280,16 @@ void subscribe() { @VisibleForTesting synchronized CompletableFuture initializeComplete() { + if (!needsInitialization) { + // initialization already complete, this must be a no-op. + // ShardConsumer must be in ProcessingState and + // any further activity will be driven by publisher pushing data to subscriber + // which invokes handleInput and that triggers ProcessTask. + // Scheduler is only meant to do health-checks to ensure the consumer + // is not stuck for any reason and to do shutdown handling. + return CompletableFuture.completedFuture(true); + } + if (taskOutcome != null) { updateState(taskOutcome); } diff --git a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ShardConsumerTest.java b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ShardConsumerTest.java index 42f88b12f..8db3d5172 100644 --- a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ShardConsumerTest.java +++ b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ShardConsumerTest.java @@ -32,7 +32,9 @@ import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; +import static org.mockito.Mockito.reset; import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.timeout; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; @@ -45,6 +47,7 @@ import java.util.Optional; import java.util.concurrent.BrokenBarrierException; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.CyclicBarrier; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; @@ -53,6 +56,7 @@ import java.util.concurrent.ThreadFactory; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Function; import org.junit.After; @@ -62,7 +66,9 @@ import org.junit.Test; import org.junit.rules.TestName; import org.junit.runner.RunWith; +import org.mockito.ArgumentCaptor; import org.mockito.Mock; +import org.mockito.MockitoAnnotations; import org.mockito.invocation.InvocationOnMock; import org.mockito.runners.MockitoJUnitRunner; import org.reactivestreams.Subscriber; @@ -148,6 +154,7 @@ public class ShardConsumerTest { @Before public void before() { + MockitoAnnotations.initMocks(this); shardInfo = new ShardInfo(shardId, concurrencyToken, null, ExtendedSequenceNumber.TRIM_HORIZON); ThreadFactory factory = new ThreadFactoryBuilder().setNameFormat("test-" + testName.getMethodName() + "-%04d") .setDaemon(true).build(); @@ -848,6 +855,114 @@ public void testLongRunningTasks() throws Exception { verifyNoMoreInteractions(taskExecutionListener); } + @Test + public void testEmptyShardProcessingRaceCondition() throws Exception { + final RecordsPublisher mockPublisher = mock(RecordsPublisher.class); + final ExecutorService mockExecutor = mock(ExecutorService.class); + final ConsumerState mockState = mock(ConsumerState.class); + final ShardConsumer consumer = new ShardConsumer(mockPublisher, mockExecutor, shardInfo, Optional.of(1L), + shardConsumerArgument, mockState, Function.identity(), 1, taskExecutionListener, 0); + + when(mockState.state()).thenReturn(ShardConsumerState.WAITING_ON_PARENT_SHARDS); + when(mockState.taskType()).thenReturn(TaskType.BLOCK_ON_PARENT_SHARDS); + final ConsumerTask mockTask = mock(ConsumerTask.class); + when(mockState.createTask(any(), any(), any())).thenReturn(mockTask); + // Simulate successful BlockedOnParent task execution + // and successful Initialize task execution + when(mockTask.call()).thenReturn(new TaskResult(false)); + + log.info("Scheduler Thread: Invoking ShardConsumer.executeLifecycle() to initiate async" + + " processing of blocked on parent task"); + consumer.executeLifecycle(); + final ArgumentCaptor taskToExecute = ArgumentCaptor.forClass(Runnable.class); + verify(mockExecutor, timeout(100)).execute(taskToExecute.capture()); + taskToExecute.getValue().run(); + log.info("RecordProcessor Thread: Simulated successful execution of Blocked on parent task"); + reset(mockExecutor); + + log.info("Scheduler Thread: Invoking ShardConsumer.executeLifecycle() to move to InitializingState" + + " and initiate async processing of initialize task"); + when(mockState.successTransition()).thenReturn(mockState); + when(mockState.state()).thenReturn(ShardConsumerState.INITIALIZING); + when(mockState.taskType()).thenReturn(TaskType.INITIALIZE); + consumer.executeLifecycle(); + verify(mockExecutor, timeout(100)).execute(taskToExecute.capture()); + log.info("RecordProcessor Thread: Simulated successful execution of Initialize task"); + taskToExecute.getValue().run(); + + log.info("Scheduler Thread: Invoking ShardConsumer.executeLifecycle() to move to ProcessingState" + + " and mark initialization future as complete"); + when(mockState.state()).thenReturn(ShardConsumerState.PROCESSING); + consumer.executeLifecycle(); + + // Simulate the race where + // scheduler invokes executeLifecycle which performs Publisher.subscribe(subscriber) + // on recordProcessor thread + // but before scheduler thread finishes initialization, handleInput is invoked + // on record processor thread. + + // Since ShardConsumer creates its own instance of subscriber that cannot be mocked + // this test sequence will appear a little odd. + // In order to control the order in which execution occurs, lets first invoke + // handleInput, although this will never happen, since there isn't a way + // to control the precise timing of the thread execution, this is the best way + final CountDownLatch processTaskLatch = new CountDownLatch(1); + new Thread(() -> { + reset(mockState); + when(mockState.taskType()).thenReturn(TaskType.PROCESS); + final ConsumerTask mockProcessTask = mock(ConsumerTask.class); + when(mockState.createTask(any(), any(), any())).thenReturn(mockProcessTask); + when(mockProcessTask.call()).then(input -> { + // first we want to wait for subscribe to be called, + // but we cannot control the timing, so wait for 10 seconds + // to let the main thread invoke executeLifecyle which + // will perform subscribe + processTaskLatch.countDown(); + log.info("Record Processor Thread: Holding shardConsumer lock, waiting for 10 seconds to" + + " let subscribe be called by scheduler thread"); + Thread.sleep(10 * 1000); + log.info("RecordProcessor Thread: Done waiting"); + // then return shard end result + log.info("RecordProcessor Thread: Simulating execution of ProcessTask and returning shard-end result"); + return new TaskResult(true); + }); + final Subscription mockSubscription = mock(Subscription.class); + consumer.handleInput(ProcessRecordsInput.builder().isAtShardEnd(true).build(), mockSubscription); + }).start(); + + processTaskLatch.await(); + + // invoke executeLifecycle, which should invoke subscribe + // meanwhile if scheduler tries to acquire the ShardConsumer lock it will + // be blocked during initialization processing because handleInput was + // already invoked and will be holding the lock. Thereby creating the + // race condition we want. + reset(mockState); + AtomicBoolean successTransitionCalled = new AtomicBoolean(false); + when(mockState.successTransition()).then(input -> { + successTransitionCalled.set(true); + return mockState; + }); + AtomicBoolean shutdownTransitionCalled = new AtomicBoolean(false); + when(mockState.shutdownTransition(any())).then(input -> { + shutdownTransitionCalled.set(true); + return mockState; + }); + when(mockState.state()).then(input -> { + if (successTransitionCalled.get() && shutdownTransitionCalled.get()) { + return ShardConsumerState.SHUTTING_DOWN; + } + return ShardConsumerState.PROCESSING; + }); + log.info("Scheduler Thread: Invoking ShardConsumer.executeLifecycle() to invoke subscribe and" + + " complete initialization"); + consumer.executeLifecycle(); + log.info("Scheduler Thread: Done initializing the ShardConsumer"); + + log.info("Verifying scheduler did not perform shutdown transition during initialization"); + verify(mockState, times(0)).shutdownTransition(any()); + } + private void mockSuccessfulShutdown(CyclicBarrier taskCallBarrier) { mockSuccessfulShutdown(taskCallBarrier, null); }