Skip to content

Commit

Permalink
Correcting the behavior of gracefulShutdown (awslabs#1302)
Browse files Browse the repository at this point in the history
* modify ShutdownTask to call shutdownComplete for graceful shutdown

* add test to verify ShutdownTask succeeds regardless of shutdownNotification

* change access level for finalShutdownLatch to NONE

* remove unused variable in GracefulShutdownCoordinator

* make comment more concise

* move waitForFinalShutdown method into GracefulShutdownCoordinator class

* cleanup call method of GracefulShutdownCoordinator

* modify waitForFinalShutdown to throw InterruptedException
  • Loading branch information
vincentvilo-aws authored Apr 3, 2024
1 parent 581d713 commit 7f1f243
Show file tree
Hide file tree
Showing 8 changed files with 142 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,22 @@
*/
package software.amazon.kinesis.coordinator;

import lombok.Builder;
import lombok.Data;
import lombok.experimental.Accessors;

import java.util.concurrent.CountDownLatch;

@Data
@Builder
@Accessors(fluent = true)
class GracefulShutdownContext {
private final CountDownLatch shutdownCompleteLatch;
private final CountDownLatch notificationCompleteLatch;
private final CountDownLatch finalShutdownLatch;
private final Scheduler scheduler;

static GracefulShutdownContext SHUTDOWN_ALREADY_COMPLETED = new GracefulShutdownContext(null, null, null);

boolean isShutdownAlreadyCompleted() {
boolean isRecordProcessorShutdownComplete() {
return shutdownCompleteLatch == null && notificationCompleteLatch == null && scheduler == null;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@

class GracefulShutdownCoordinator {

/**
* arbitrary wait time for worker's finalShutdown
*/
private static final long FINAL_SHUTDOWN_WAIT_TIME_SECONDS = 60L;

CompletableFuture<Boolean> startGracefulShutdown(Callable<Boolean> shutdownCallable) {
CompletableFuture<Boolean> cf = new CompletableFuture<>();
CompletableFuture.runAsync(() -> {
Expand Down Expand Up @@ -62,7 +67,18 @@ private String awaitingFinalShutdownMessage(GracefulShutdownContext context) {
return String.format("Waiting for %d record processors to complete final shutdown", outstanding);
}

/**
* used to wait for the worker's final shutdown to complete before returning the future for graceful shutdown
* @return true if the final shutdown is successful, false otherwise.
*/
private boolean waitForFinalShutdown(GracefulShutdownContext context) throws InterruptedException {
return context.finalShutdownLatch().await(FINAL_SHUTDOWN_WAIT_TIME_SECONDS, TimeUnit.SECONDS);
}

private boolean waitForRecordProcessors(GracefulShutdownContext context) {
if (context.isRecordProcessorShutdownComplete()) {
return true;
}

//
// Awaiting for all ShardConsumer/RecordProcessors to be notified that a shutdown has been requested.
Expand Down Expand Up @@ -148,14 +164,13 @@ private boolean workerShutdownWithRemaining(long outstanding, GracefulShutdownCo

@Override
public Boolean call() throws Exception {
GracefulShutdownContext context;
try {
context = startWorkerShutdown.call();
final GracefulShutdownContext context = startWorkerShutdown.call();
return waitForRecordProcessors(context) && waitForFinalShutdown(context);
} catch (Exception ex) {
log.warn("Caught exception while requesting initial worker shutdown.", ex);
throw ex;
}
return context.isShutdownAlreadyCompleted() || waitForRecordProcessors(context);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,14 @@ public class Scheduler implements Runnable {
* Used to ensure that only one requestedShutdown is in progress at a time.
*/
private CompletableFuture<Boolean> gracefulShutdownFuture;

/**
* CountDownLatch used by the GracefulShutdownCoordinator. Reaching zero means that
* the scheduler's finalShutdown() call has completed.
*/
@Getter(AccessLevel.NONE)
private final CountDownLatch finalShutdownLatch = new CountDownLatch(1);

@VisibleForTesting
protected boolean gracefuleShutdownStarted = false;

Expand Down Expand Up @@ -797,7 +805,7 @@ Callable<GracefulShutdownContext> createWorkerShutdownCallable() {
// If there are no leases notification is already completed, but we still need to shutdown the worker.
//
this.shutdown();
return GracefulShutdownContext.SHUTDOWN_ALREADY_COMPLETED;
return GracefulShutdownContext.builder().finalShutdownLatch(finalShutdownLatch).build();
}
CountDownLatch shutdownCompleteLatch = new CountDownLatch(leases.size());
CountDownLatch notificationCompleteLatch = new CountDownLatch(leases.size());
Expand All @@ -818,7 +826,12 @@ Callable<GracefulShutdownContext> createWorkerShutdownCallable() {
shutdownCompleteLatch.countDown();
}
}
return new GracefulShutdownContext(shutdownCompleteLatch, notificationCompleteLatch, this);
return GracefulShutdownContext.builder()
.shutdownCompleteLatch(shutdownCompleteLatch)
.notificationCompleteLatch(notificationCompleteLatch)
.finalShutdownLatch(finalShutdownLatch)
.scheduler(this)
.build();
};
}

Expand Down Expand Up @@ -878,6 +891,7 @@ private void finalShutdown() {
((CloudWatchMetricsFactory) metricsFactory).shutdown();
}
shutdownComplete = true;
finalShutdownLatch.countDown();
}

private List<ShardInfo> getShardInfoForAssignments() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,7 @@ public ConsumerTask createTask(ShardConsumerArgument argument, ShardConsumer con
argument.shardRecordProcessor(),
argument.recordProcessorCheckpointer(),
consumer.shutdownReason(),
consumer.shutdownNotification(),
argument.initialPositionInStream(),
argument.cleanupLeasesOfCompletedShards(),
argument.ignoreUnexpectedChildShards(),
Expand Down Expand Up @@ -557,9 +558,6 @@ static class ShutdownCompleteState implements ConsumerState {

@Override
public ConsumerTask createTask(ShardConsumerArgument argument, ShardConsumer consumer, ProcessRecordsInput input) {
if (consumer.shutdownNotification() != null) {
consumer.shutdownNotification().shutdownComplete();
}
return null;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ public class ShutdownTask implements ConsumerTask {
private final ShardRecordProcessorCheckpointer recordProcessorCheckpointer;
@NonNull
private final ShutdownReason reason;
private final ShutdownNotification shutdownNotification;
@NonNull
private final InitialPositionInStreamExtended initialPositionInStream;
private final boolean cleanupLeasesOfCompletedShards;
Expand Down Expand Up @@ -149,6 +150,12 @@ public TaskResult call() {

log.debug("Shutting down retrieval strategy for shard {}.", leaseKey);
recordsPublisher.shutdown();

// shutdownNotification is only set and used when gracefulShutdown starts
if (shutdownNotification != null) {
shutdownNotification.shutdownComplete();
}

log.debug("Record processor completed shutdown() for shard {}", leaseKey);

return new TaskResult(null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ public class GracefulShutdownCoordinatorTest {
@Mock
private CountDownLatch notificationCompleteLatch;
@Mock
private CountDownLatch finalShutdownLatch;
@Mock
private Scheduler scheduler;
@Mock
private Callable<GracefulShutdownContext> contextCallable;
Expand All @@ -57,6 +59,7 @@ public void testAllShutdownCompletedAlready() throws Exception {

when(shutdownCompleteLatch.await(anyLong(), any(TimeUnit.class))).thenReturn(true);
when(notificationCompleteLatch.await(anyLong(), any(TimeUnit.class))).thenReturn(true);
when(finalShutdownLatch.await(anyLong(), any(TimeUnit.class))).thenReturn(true);

assertThat(requestedShutdownCallable.call(), equalTo(true));
verify(shutdownCompleteLatch).await(anyLong(), any(TimeUnit.class));
Expand All @@ -72,6 +75,7 @@ public void testNotificationNotCompletedYet() throws Exception {
when(notificationCompleteLatch.getCount()).thenReturn(1L, 0L);
mockLatchAwait(shutdownCompleteLatch, true);
when(shutdownCompleteLatch.getCount()).thenReturn(1L, 1L, 0L);
when(finalShutdownLatch.await(anyLong(), any(TimeUnit.class))).thenReturn(true);

when(scheduler.shutdownComplete()).thenReturn(false, true);
mockShardInfoConsumerMap(1, 0);
Expand All @@ -93,6 +97,7 @@ public void testShutdownNotCompletedYet() throws Exception {
mockLatchAwait(notificationCompleteLatch, true);
mockLatchAwait(shutdownCompleteLatch, false, true);
when(shutdownCompleteLatch.getCount()).thenReturn(1L, 0L);
when(finalShutdownLatch.await(anyLong(), any(TimeUnit.class))).thenReturn(true);

when(scheduler.shutdownComplete()).thenReturn(false, true);
mockShardInfoConsumerMap(1, 0);
Expand All @@ -117,6 +122,8 @@ public void testMultipleAttemptsForNotification() throws Exception {
mockLatchAwait(shutdownCompleteLatch, true);
when(shutdownCompleteLatch.getCount()).thenReturn(2L, 2L, 1L, 1L, 0L);

when(finalShutdownLatch.await(anyLong(), any(TimeUnit.class))).thenReturn(true);

when(scheduler.shutdownComplete()).thenReturn(false, false, false, true);
mockShardInfoConsumerMap(2, 1, 0);

Expand Down Expand Up @@ -286,6 +293,44 @@ public void testWorkerShutdownCallableThrows() throws Exception {
requestedShutdownCallable.call();
}

@Test
public void testShutdownFailsDueToRecordProcessors() throws Exception {
Callable<Boolean> requestedShutdownCallable = buildRequestedShutdownCallable();

when(notificationCompleteLatch.await(anyLong(), any(TimeUnit.class))).thenReturn(true);
when(shutdownCompleteLatch.await(anyLong(), any(TimeUnit.class))).thenReturn(false);
when(shutdownCompleteLatch.getCount()).thenReturn(1L);
when(scheduler.shutdownComplete()).thenReturn(true);
mockShardInfoConsumerMap(1);

assertThat(requestedShutdownCallable.call(), equalTo(false));
verifyLatchAwait(shutdownCompleteLatch);
}

@Test
public void testShutdownFailsDueToWorker() throws Exception {
Callable<Boolean> requestedShutdownCallable = buildRequestedShutdownCallable();

when(notificationCompleteLatch.await(anyLong(), any(TimeUnit.class))).thenReturn(true);
when(shutdownCompleteLatch.await(anyLong(), any(TimeUnit.class))).thenReturn(true);
when(finalShutdownLatch.await(anyLong(), any(TimeUnit.class))).thenReturn(false);

assertThat(requestedShutdownCallable.call(), equalTo(false));
verifyLatchAwait(finalShutdownLatch);
}

/**
* tests that shutdown still succeeds in the case where there are no leases returned by the lease coordinator
*/
@Test
public void testShutdownSuccessWithNoLeases() throws Exception {
Callable<Boolean> requestedShutdownCallable = buildRequestedShutdownCallableWithNullLatches();
when(finalShutdownLatch.await(anyLong(), any(TimeUnit.class))).thenReturn(true);

assertThat(requestedShutdownCallable.call(), equalTo(true));
verifyLatchAwait(finalShutdownLatch);
}

private void verifyLatchAwait(CountDownLatch latch) throws Exception {
verifyLatchAwait(latch, times(1));
}
Expand All @@ -303,8 +348,24 @@ private void mockLatchAwait(CountDownLatch latch, Boolean initial, Boolean... re
}

private Callable<Boolean> buildRequestedShutdownCallable() throws Exception {
GracefulShutdownContext context = new GracefulShutdownContext(shutdownCompleteLatch,
notificationCompleteLatch, scheduler);
GracefulShutdownContext context = GracefulShutdownContext.builder()
.shutdownCompleteLatch(shutdownCompleteLatch)
.notificationCompleteLatch(notificationCompleteLatch)
.finalShutdownLatch(finalShutdownLatch)
.scheduler(scheduler)
.build();
when(contextCallable.call()).thenReturn(context);
return new GracefulShutdownCoordinator().createGracefulShutdownCallable(contextCallable);
}

/**
* finalShutdownLatch will always be initialized, but shutdownCompleteLatch and notificationCompleteLatch are not
* initialized in the case where there are no leases returned by the lease coordinator
*/
private Callable<Boolean> buildRequestedShutdownCallableWithNullLatches() throws Exception {
GracefulShutdownContext context = GracefulShutdownContext.builder()
.finalShutdownLatch(finalShutdownLatch)
.build();
when(contextCallable.call()).thenReturn(context);
return new GracefulShutdownCoordinator().createGracefulShutdownCallable(contextCallable);
}
Expand All @@ -319,4 +380,5 @@ private void mockShardInfoConsumerMap(Integer initialItemCount, Integer... addit
when(shardInfoConsumerMap.isEmpty()).thenReturn(initialItemCount == 0, additionalEmptyStates);
}


}
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,7 @@
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.nullValue;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static software.amazon.kinesis.lifecycle.ConsumerStates.ShardConsumerState;

Expand Down Expand Up @@ -355,28 +352,17 @@ public void shutdownCompleteStateTest() {
ConsumerState state = ShardConsumerState.SHUTDOWN_COMPLETE.consumerState();

assertThat(state.createTask(argument, consumer, null), nullValue());
verify(consumer, times(2)).shutdownNotification();
verify(shutdownNotification).shutdownComplete();

assertThat(state.successTransition(), equalTo(state));
for (ShutdownReason reason : ShutdownReason.values()) {
assertThat(state.shutdownTransition(reason), equalTo(state));
}

assertThat(state.isTerminal(), equalTo(true));
assertThat(state.state(), equalTo(ShardConsumerState.SHUTDOWN_COMPLETE));
assertThat(state.taskType(), equalTo(TaskType.SHUTDOWN_COMPLETE));
}

@Test
public void shutdownCompleteStateNullNotificationTest() {
ConsumerState state = ShardConsumerState.SHUTDOWN_COMPLETE.consumerState();

when(consumer.shutdownNotification()).thenReturn(null);
assertThat(state.createTask(argument, consumer, null), nullValue());

verify(consumer).shutdownNotification();
verify(shutdownNotification, never()).shutdownComplete();
}

static <ValueType> ReflectionPropertyMatcher<ShutdownTask, ValueType> shutdownTask(Class<ValueType> valueTypeClass,
String propertyName, Matcher<ValueType> matcher) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ public class ShutdownTaskTest {
private ShardRecordProcessor shardRecordProcessor;
@Mock
private LeaseCleanupManager leaseCleanupManager;
@Mock
private ShutdownNotification shutdownNotification;

@Before
public void setUp() throws Exception {
Expand Down Expand Up @@ -308,6 +310,26 @@ public void testNullChildShards() throws Exception {
verify(leaseRefresher, never()).createLeaseIfNotExists(any(Lease.class));
}

/**
* shutdownNotification is only set when ShardConsumer.gracefulShutdown() is called and should be null otherwise.
* The task should still call recordsPublisher.shutdown() regardless of the notification
*/
@Test
public void testCallWhenShutdownNotificationIsSet() {
final TaskResult result = createShutdownTaskWithNotification(LEASE_LOST, Collections.emptyList()).call();
assertNull(result.getException());
verify(recordsPublisher).shutdown();
verify(shutdownNotification).shutdownComplete();
}

@Test
public void testCallWhenShutdownNotificationIsNull() {
final TaskResult result = createShutdownTask(LEASE_LOST, Collections.emptyList()).call();
assertNull(result.getException());
verify(recordsPublisher).shutdown();
verify(shutdownNotification, never()).shutdownComplete();
}

/**
* Test method for {@link ShutdownTask#taskType()}.
*/
Expand Down Expand Up @@ -372,7 +394,15 @@ private ShutdownTask createShutdownTask(final ShutdownReason reason, final List<
private ShutdownTask createShutdownTask(final ShutdownReason reason, final List<ChildShard> childShards,
final ShardInfo shardInfo) {
return new ShutdownTask(shardInfo, shardDetector, shardRecordProcessor, recordProcessorCheckpointer,
reason, INITIAL_POSITION_TRIM_HORIZON, false, false,
reason, null, INITIAL_POSITION_TRIM_HORIZON, false, false,
leaseCoordinator, TASK_BACKOFF_TIME_MILLIS, recordsPublisher, hierarchicalShardSyncer,
NULL_METRICS_FACTORY, childShards, STREAM_IDENTIFIER, leaseCleanupManager);
}

private ShutdownTask createShutdownTaskWithNotification(final ShutdownReason reason,
final List<ChildShard> childShards) {
return new ShutdownTask(SHARD_INFO, shardDetector, shardRecordProcessor, recordProcessorCheckpointer,
reason, shutdownNotification, INITIAL_POSITION_TRIM_HORIZON, false, false,
leaseCoordinator, TASK_BACKOFF_TIME_MILLIS, recordsPublisher, hierarchicalShardSyncer,
NULL_METRICS_FACTORY, childShards, STREAM_IDENTIFIER, leaseCleanupManager);
}
Expand Down

0 comments on commit 7f1f243

Please sign in to comment.