Skip to content

Commit

Permalink
Fix a race condition between ShardConsumer shutdown and initialization (
Browse files Browse the repository at this point in the history
#1319)

* Fix a race condition between ShardConsumer shutdown and initialization

When Kinesis shards have no data, there can be a race condition where
the shard-end record processing from RecordProcessorThread
interleaves with Scheduler performing initialization.
This leads to ShardConsumer making incorrect state transition
during initialization (moves from PROCESSING -> SHUTTING_DOWN) state
and during shutdown handling it moves from SHUTTING_DOWN -> SHUTDOWN_COMPLETE
without running the ShutdownTask.

This can cause the ShardConsumer to not perform proper shutdown
processing that is required for a child shard processing
to be unblocked. So the child shard could be blocked forever unless the
lease for the parent shard moves to a new worker and that worker does
not run into the race condition.

This patch fixes the race condition as follows:

The intializationComplete invocation is not needed after
needsInitialization has been set to false. Because initializationComplete
is mean to perform initialization in an async manner, but once
its done, the async task is a no-op in happy-path, but it can
perform incorrect state transition during a race condition.
  • Loading branch information
akidambisrinivasan authored May 2, 2024
1 parent 69cf599 commit 16e8404
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,10 @@ public void executeLifecycle() {
// Task rejection during the subscribe() call will not be propagated back as it not executed
// in the context of the Scheduler thread. Hence we should not assume the subscription will
// always be successful.
// But if subscription was not successful, then it will recover
// during healthCheck which will restart subscription.
// From Shardconsumer point of view, initialization after the below subscribe call
// is complete
subscribe();
needsInitialization = false;
}
Expand Down Expand Up @@ -276,6 +280,16 @@ void subscribe() {

@VisibleForTesting
synchronized CompletableFuture<Boolean> initializeComplete() {
if (!needsInitialization) {
// initialization already complete, this must be a no-op.
// ShardConsumer must be in ProcessingState and
// any further activity will be driven by publisher pushing data to subscriber
// which invokes handleInput and that triggers ProcessTask.
// Scheduler is only meant to do health-checks to ensure the consumer
// is not stuck for any reason and to do shutdown handling.
return CompletableFuture.completedFuture(true);
}

if (taskOutcome != null) {
updateState(taskOutcome);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.timeout;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
Expand All @@ -45,6 +47,7 @@
import java.util.Optional;
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
Expand All @@ -53,6 +56,7 @@
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;

import org.junit.After;
Expand All @@ -62,7 +66,9 @@
import org.junit.Test;
import org.junit.rules.TestName;
import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.runners.MockitoJUnitRunner;
import org.reactivestreams.Subscriber;
Expand Down Expand Up @@ -148,6 +154,7 @@ public class ShardConsumerTest {

@Before
public void before() {
MockitoAnnotations.initMocks(this);
shardInfo = new ShardInfo(shardId, concurrencyToken, null, ExtendedSequenceNumber.TRIM_HORIZON);
ThreadFactory factory = new ThreadFactoryBuilder().setNameFormat("test-" + testName.getMethodName() + "-%04d")
.setDaemon(true).build();
Expand Down Expand Up @@ -848,6 +855,114 @@ public void testLongRunningTasks() throws Exception {
verifyNoMoreInteractions(taskExecutionListener);
}

@Test
public void testEmptyShardProcessingRaceCondition() throws Exception {
final RecordsPublisher mockPublisher = mock(RecordsPublisher.class);
final ExecutorService mockExecutor = mock(ExecutorService.class);
final ConsumerState mockState = mock(ConsumerState.class);
final ShardConsumer consumer = new ShardConsumer(mockPublisher, mockExecutor, shardInfo, Optional.of(1L),
shardConsumerArgument, mockState, Function.identity(), 1, taskExecutionListener, 0);

when(mockState.state()).thenReturn(ShardConsumerState.WAITING_ON_PARENT_SHARDS);
when(mockState.taskType()).thenReturn(TaskType.BLOCK_ON_PARENT_SHARDS);
final ConsumerTask mockTask = mock(ConsumerTask.class);
when(mockState.createTask(any(), any(), any())).thenReturn(mockTask);
// Simulate successful BlockedOnParent task execution
// and successful Initialize task execution
when(mockTask.call()).thenReturn(new TaskResult(false));

log.info("Scheduler Thread: Invoking ShardConsumer.executeLifecycle() to initiate async" +
" processing of blocked on parent task");
consumer.executeLifecycle();
final ArgumentCaptor<Runnable> taskToExecute = ArgumentCaptor.forClass(Runnable.class);
verify(mockExecutor, timeout(100)).execute(taskToExecute.capture());
taskToExecute.getValue().run();
log.info("RecordProcessor Thread: Simulated successful execution of Blocked on parent task");
reset(mockExecutor);

log.info("Scheduler Thread: Invoking ShardConsumer.executeLifecycle() to move to InitializingState" +
" and initiate async processing of initialize task");
when(mockState.successTransition()).thenReturn(mockState);
when(mockState.state()).thenReturn(ShardConsumerState.INITIALIZING);
when(mockState.taskType()).thenReturn(TaskType.INITIALIZE);
consumer.executeLifecycle();
verify(mockExecutor, timeout(100)).execute(taskToExecute.capture());
log.info("RecordProcessor Thread: Simulated successful execution of Initialize task");
taskToExecute.getValue().run();

log.info("Scheduler Thread: Invoking ShardConsumer.executeLifecycle() to move to ProcessingState" +
" and mark initialization future as complete");
when(mockState.state()).thenReturn(ShardConsumerState.PROCESSING);
consumer.executeLifecycle();

// Simulate the race where
// scheduler invokes executeLifecycle which performs Publisher.subscribe(subscriber)
// on recordProcessor thread
// but before scheduler thread finishes initialization, handleInput is invoked
// on record processor thread.

// Since ShardConsumer creates its own instance of subscriber that cannot be mocked
// this test sequence will appear a little odd.
// In order to control the order in which execution occurs, lets first invoke
// handleInput, although this will never happen, since there isn't a way
// to control the precise timing of the thread execution, this is the best way
final CountDownLatch processTaskLatch = new CountDownLatch(1);
new Thread(() -> {
reset(mockState);
when(mockState.taskType()).thenReturn(TaskType.PROCESS);
final ConsumerTask mockProcessTask = mock(ConsumerTask.class);
when(mockState.createTask(any(), any(), any())).thenReturn(mockProcessTask);
when(mockProcessTask.call()).then(input -> {
// first we want to wait for subscribe to be called,
// but we cannot control the timing, so wait for 10 seconds
// to let the main thread invoke executeLifecyle which
// will perform subscribe
processTaskLatch.countDown();
log.info("Record Processor Thread: Holding shardConsumer lock, waiting for 10 seconds to" +
" let subscribe be called by scheduler thread");
Thread.sleep(10 * 1000);
log.info("RecordProcessor Thread: Done waiting");
// then return shard end result
log.info("RecordProcessor Thread: Simulating execution of ProcessTask and returning shard-end result");
return new TaskResult(true);
});
final Subscription mockSubscription = mock(Subscription.class);
consumer.handleInput(ProcessRecordsInput.builder().isAtShardEnd(true).build(), mockSubscription);
}).start();

processTaskLatch.await();

// invoke executeLifecycle, which should invoke subscribe
// meanwhile if scheduler tries to acquire the ShardConsumer lock it will
// be blocked during initialization processing because handleInput was
// already invoked and will be holding the lock. Thereby creating the
// race condition we want.
reset(mockState);
AtomicBoolean successTransitionCalled = new AtomicBoolean(false);
when(mockState.successTransition()).then(input -> {
successTransitionCalled.set(true);
return mockState;
});
AtomicBoolean shutdownTransitionCalled = new AtomicBoolean(false);
when(mockState.shutdownTransition(any())).then(input -> {
shutdownTransitionCalled.set(true);
return mockState;
});
when(mockState.state()).then(input -> {
if (successTransitionCalled.get() && shutdownTransitionCalled.get()) {
return ShardConsumerState.SHUTTING_DOWN;
}
return ShardConsumerState.PROCESSING;
});
log.info("Scheduler Thread: Invoking ShardConsumer.executeLifecycle() to invoke subscribe and" +
" complete initialization");
consumer.executeLifecycle();
log.info("Scheduler Thread: Done initializing the ShardConsumer");

log.info("Verifying scheduler did not perform shutdown transition during initialization");
verify(mockState, times(0)).shutdownTransition(any());
}

private void mockSuccessfulShutdown(CyclicBarrier taskCallBarrier) {
mockSuccessfulShutdown(taskCallBarrier, null);
}
Expand Down

0 comments on commit 16e8404

Please sign in to comment.