Skip to content

Commit

Permalink
Add AddingTrailingDataSubscriber to allow users to send additional da…
Browse files Browse the repository at this point in the history
…ta t… (#4366)

* Add AdditionalDataSubscriber to allow users to send additional data to the downstream subscriber

* Support iterable
  • Loading branch information
zoewangg authored Sep 5, 2023
1 parent 34aa46d commit f8c1cb0
Show file tree
Hide file tree
Showing 5 changed files with 377 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.function.Supplier;
import org.reactivestreams.Publisher;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import software.amazon.awssdk.annotations.SdkPublicApi;
import software.amazon.awssdk.utils.async.AddingTrailingDataSubscriber;
import software.amazon.awssdk.utils.async.BufferingSubscriber;
import software.amazon.awssdk.utils.async.EventListeningSubscriber;
import software.amazon.awssdk.utils.async.FilteringSubscriber;
Expand Down Expand Up @@ -118,6 +120,18 @@ default SdkPublisher<T> limit(int limit) {
return subscriber -> subscribe(new LimitingSubscriber<>(subscriber, limit));
}


/**
* Creates a new publisher that emits trailing events provided by {@code trailingDataSupplier} in addition to the
* published events.
*
* @param trailingDataSupplier supplier to provide the trailing data
* @return New publisher that will publish additional events
*/
default SdkPublisher<T> addTrailingData(Supplier<Iterable<T>> trailingDataSupplier) {
return subscriber -> subscribe(new AddingTrailingDataSubscriber<T>(subscriber, trailingDataSupplier));
}

/**
* Add a callback that will be invoked after this publisher invokes {@link Subscriber#onComplete()}.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
Expand Down Expand Up @@ -141,6 +142,23 @@ public void flatMapIterableHandlesError() {
.hasCause(exception);
}

@Test
public void addTrailingData_handlesCorrectly() {
FakeSdkPublisher<String> fakePublisher = new FakeSdkPublisher<>();

FakeStringSubscriber fakeSubscriber = new FakeStringSubscriber();
fakePublisher.addTrailingData(() -> Arrays.asList("two", "three"))
.subscribe(fakeSubscriber);

fakePublisher.publish("one");
fakePublisher.complete();

assertThat(fakeSubscriber.recordedEvents()).containsExactly("one", "two", "three");
assertThat(fakeSubscriber.isComplete()).isTrue();
assertThat(fakeSubscriber.isError()).isFalse();
}


private final static class FakeByteBufferSubscriber implements Subscriber<ByteBuffer> {
private final List<String> recordedEvents = new ArrayList<>();

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package software.amazon.awssdk.utils.async;

import java.util.Iterator;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Supplier;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import software.amazon.awssdk.annotations.SdkProtectedApi;
import software.amazon.awssdk.utils.Logger;
import software.amazon.awssdk.utils.Validate;

/**
* Allows to send trailing data before invoking onComplete on the downstream subscriber.
* trailingDataIterable will be created when the upstream subscriber has called onComplete.
*/
@SdkProtectedApi
public class AddingTrailingDataSubscriber<T> extends DelegatingSubscriber<T, T> {
private static final Logger log = Logger.loggerFor(AddingTrailingDataSubscriber.class);

/**
* The subscription to the upstream subscriber.
*/
private Subscription upstreamSubscription;

/**
* The amount of unfulfilled demand the downstream subscriber has opened against us.
*/
private final AtomicLong downstreamDemand = new AtomicLong(0);

/**
* Whether the upstream subscriber has called onComplete on us.
*/
private volatile boolean onCompleteCalledByUpstream = false;

/**
* Whether the upstream subscriber has called onError on us.
*/
private volatile boolean onErrorCalledByUpstream = false;

/**
* Whether we have called onComplete on the downstream subscriber.
*/
private volatile boolean onCompleteCalledOnDownstream = false;

private final Supplier<Iterable<T>> trailingDataIterableSupplier;
private Iterator<T> trailingDataIterator;

public AddingTrailingDataSubscriber(Subscriber<? super T> subscriber,
Supplier<Iterable<T>> trailingDataIterableSupplier) {
super(Validate.paramNotNull(subscriber, "subscriber"));
this.trailingDataIterableSupplier = Validate.paramNotNull(trailingDataIterableSupplier, "trailingDataIterableSupplier");
}

@Override
public void onSubscribe(Subscription subscription) {

if (upstreamSubscription != null) {
log.warn(() -> "Received duplicate subscription, cancelling the duplicate.", new IllegalStateException());
subscription.cancel();
return;
}

upstreamSubscription = subscription;

subscriber.onSubscribe(new Subscription() {

@Override
public void request(long l) {
if (onErrorCalledByUpstream || onCompleteCalledOnDownstream) {
return;
}

addDownstreamDemand(l);

if (onCompleteCalledByUpstream) {
sendTrailingDataAndCompleteIfNeeded();
return;
}
upstreamSubscription.request(l);
}

@Override
public void cancel() {
upstreamSubscription.cancel();
}
});
}

@Override
public void onError(Throwable throwable) {
onErrorCalledByUpstream = true;
subscriber.onError(throwable);
}

@Override
public void onNext(T t) {
Validate.paramNotNull(t, "item");
downstreamDemand.decrementAndGet();
subscriber.onNext(t);
}

@Override
public void onComplete() {
onCompleteCalledByUpstream = true;
sendTrailingDataAndCompleteIfNeeded();
}

private void addDownstreamDemand(long l) {

if (l > 0) {
downstreamDemand.getAndUpdate(current -> {
long newValue = current + l;
return newValue >= 0 ? newValue : Long.MAX_VALUE;
});
} else {
upstreamSubscription.cancel();
onError(new IllegalArgumentException("Demand must not be negative"));
}
}

private synchronized void sendTrailingDataAndCompleteIfNeeded() {
if (onCompleteCalledOnDownstream) {
return;
}

if (trailingDataIterator == null) {
Iterable<T> supplier = trailingDataIterableSupplier.get();
if (supplier == null) {
completeDownstreamSubscriber();
return;
}

trailingDataIterator = supplier.iterator();
}

sendTrailingDataIfNeeded();

if (!trailingDataIterator.hasNext()) {
completeDownstreamSubscriber();
}
}

private void sendTrailingDataIfNeeded() {
long demand = downstreamDemand.get();

while (trailingDataIterator.hasNext() && demand > 0) {
subscriber.onNext(trailingDataIterator.next());
demand = downstreamDemand.decrementAndGet();
}
}

private void completeDownstreamSubscriber() {
subscriber.onComplete();
onCompleteCalledOnDownstream = true;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package software.amazon.awssdk.utils.async;

import java.util.Arrays;
import java.util.concurrent.CompletableFuture;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import org.reactivestreams.tck.SubscriberWhiteboxVerification;
import org.reactivestreams.tck.TestEnvironment;

public class AddingTrailingDataSubscriberTckTest extends SubscriberWhiteboxVerification<Integer> {
protected AddingTrailingDataSubscriberTckTest() {
super(new TestEnvironment());
}

@Override
public Subscriber<Integer> createSubscriber(WhiteboxSubscriberProbe<Integer> probe) {
Subscriber<Integer> foo = new SequentialSubscriber<>(s -> {}, new CompletableFuture<>());

return new AddingTrailingDataSubscriber<Integer>(foo, () -> Arrays.asList(0, 1, 2)) {
@Override
public void onError(Throwable throwable) {
super.onError(throwable);
probe.registerOnError(throwable);
}

@Override
public void onSubscribe(Subscription subscription) {
super.onSubscribe(subscription);
probe.registerOnSubscribe(new SubscriberPuppet() {
@Override
public void triggerRequest(long elements) {
subscription.request(elements);
}

@Override
public void signalCancel() {
subscription.cancel();
}
});
}

@Override
public void onNext(Integer nextItem) {
super.onNext(nextItem);
probe.registerOnNext(nextItem);
}

@Override
public void onComplete() {
super.onComplete();
probe.registerOnComplete();
}
};
}

@Override
public Integer createElement(int i) {
return i;
}
}
Loading

0 comments on commit f8c1cb0

Please sign in to comment.