Skip to content

Commit

Permalink
Fixes Reactor context propagation, including tests
Browse files Browse the repository at this point in the history
  • Loading branch information
svametcalf committed Dec 28, 2023
1 parent 13df614 commit a11f716
Show file tree
Hide file tree
Showing 4 changed files with 234 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
package com.salesforce.reactorgrpc.stub;

import io.grpc.CallOptions;
import io.grpc.stub.CallStreamObserver;
import io.grpc.stub.StreamObserver;
import java.util.function.BiConsumer;
import java.util.function.Function;
Expand All @@ -34,9 +33,8 @@ public static <TRequest, TResponse> Mono<TResponse> oneToOne(
BiConsumer<TRequest, StreamObserver<TResponse>> delegate,
CallOptions options) {
try {
return Mono
.<TResponse>create(emitter -> monoSource.subscribe(
request -> delegate.accept(request, new StreamObserver<TResponse>() {
return monoSource.flatMap(r ->
Mono.<TResponse>create(emitter -> delegate.accept(r, new StreamObserver<TResponse>() {
@Override
public void onNext(TResponse tResponse) {
emitter.success(tResponse);
Expand All @@ -51,10 +49,10 @@ public void onError(Throwable throwable) {
public void onCompleted() {
// Do nothing
}
}),
emitter::error
))
.transform(Operators.lift(new SubscribeOnlyOnceLifter<TResponse>()));
})
)
.transform(Operators.lift(new SubscribeOnlyOnceLifter<TResponse>()))
);
} catch (Throwable throwable) {
return Mono.error(throwable);
}
Expand Down Expand Up @@ -97,17 +95,8 @@ public static <TRequest, TResponse> Mono<TResponse> manyToOne(
Function<StreamObserver<TResponse>, StreamObserver<TRequest>> delegate,
CallOptions options) {
try {
ReactorSubscriberAndClientProducer<TRequest> subscriberAndGRPCProducer =
fluxSource.subscribeWith(new ReactorSubscriberAndClientProducer<>());
ReactorClientStreamObserverAndPublisher<TResponse> observerAndPublisher =
new ReactorClientStreamObserverAndPublisher<>(
s -> subscriberAndGRPCProducer.subscribe((CallStreamObserver<TRequest>) s),
subscriberAndGRPCProducer::cancel
);

return Flux.from(observerAndPublisher)
.doOnSubscribe(s -> delegate.apply(observerAndPublisher))
.singleOrEmpty();
ReactorGrpcClientCallFlux<TRequest, TResponse> operator = new ReactorGrpcClientCallFlux<>(fluxSource, delegate);
return operator.doOnSubscribe(operator.onSubscribeHook()).singleOrEmpty();
} catch (Throwable throwable) {
return Mono.error(throwable);
}
Expand All @@ -123,19 +112,11 @@ public static <TRequest, TResponse> Flux<TResponse> manyToMany(
Function<StreamObserver<TResponse>, StreamObserver<TRequest>> delegate,
CallOptions options) {
try {

final int prefetch = ReactorCallOptions.getPrefetch(options);
final int lowTide = ReactorCallOptions.getLowTide(options);

ReactorSubscriberAndClientProducer<TRequest> subscriberAndGRPCProducer =
fluxSource.subscribeWith(new ReactorSubscriberAndClientProducer<>());
ReactorClientStreamObserverAndPublisher<TResponse> observerAndPublisher =
new ReactorClientStreamObserverAndPublisher<>(
s -> subscriberAndGRPCProducer.subscribe((CallStreamObserver<TRequest>) s),
subscriberAndGRPCProducer::cancel, prefetch, lowTide
);

return Flux.from(observerAndPublisher).doOnSubscribe(s -> delegate.apply(observerAndPublisher));
ReactorGrpcClientCallFlux<TRequest, TResponse> operator = new ReactorGrpcClientCallFlux<>(fluxSource, delegate, prefetch, lowTide);
return operator.doOnSubscribe(operator.onSubscribeHook());
} catch (Throwable throwable) {
return Flux.error(throwable);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/*
* Copyright (c) 2019, Salesforce.com, Inc.
* All rights reserved.
* Licensed under the BSD 3-Clause license.
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
*/
package com.salesforce.reactorgrpc.stub;

import io.grpc.stub.CallStreamObserver;
import io.grpc.stub.StreamObserver;
import org.reactivestreams.Subscription;
import reactor.core.CoreSubscriber;
import reactor.core.publisher.Flux;
import reactor.core.publisher.FluxOperator;
import reactor.util.context.Context;

import java.util.function.Consumer;
import java.util.function.Function;

/**
* Create a {@link Flux} that allows for correct context propagation in client calls
*
* @param <TRequest>
* @param <TResponse>
*/
public class ReactorGrpcClientCallFlux<TRequest, TResponse> extends FluxOperator<TRequest, TResponse> {

private final ReactorSubscriberAndClientProducer<TRequest> requestConsumer;
private final ReactorClientStreamObserverAndPublisher<TResponse> responsePublisher;
private final Function<StreamObserver<TResponse>, StreamObserver<TRequest>> delegate;

ReactorGrpcClientCallFlux(Flux<TRequest> in, Function<StreamObserver<TResponse>, StreamObserver<TRequest>> delegate) {
super(in);
this.delegate = delegate;
this.requestConsumer = new ReactorSubscriberAndClientProducer<>();
this.responsePublisher = new ReactorClientStreamObserverAndPublisher<>(s -> requestConsumer.subscribe((CallStreamObserver<TRequest>) s), requestConsumer::cancel);
}

public ReactorGrpcClientCallFlux(Flux<TRequest> in, Function<StreamObserver<TResponse>, StreamObserver<TRequest>> delegate, int prefetch, int lowTide) {
super(in);
this.delegate = delegate;
this.requestConsumer = new ReactorSubscriberAndClientProducer<>();
this.responsePublisher = new ReactorClientStreamObserverAndPublisher<>(s -> requestConsumer.subscribe((CallStreamObserver<TRequest>) s), requestConsumer::cancel, prefetch, lowTide);
}

public Consumer<? super Subscription> onSubscribeHook() {
return s -> this.delegate.apply(responsePublisher);
}

@Override
public void subscribe(CoreSubscriber<? super TResponse> actual) {
responsePublisher.subscribe(actual);
source.subscribe(new CoreSubscriber<TRequest>() {
@Override
public void onSubscribe(Subscription s) {
requestConsumer.onSubscribe(s);
}

@Override
public void onNext(TRequest tRequest) {
requestConsumer.onNext(tRequest);
}

@Override
public void onError(Throwable throwable) {
requestConsumer.onError(throwable);
}

@Override
public void onComplete() {
requestConsumer.onComplete();
}

@Override
public Context currentContext() {
return actual.currentContext();
}
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import org.reactivestreams.Subscription;
import reactor.core.CoreSubscriber;
import reactor.core.Scannable;
import reactor.util.context.Context;

import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.BiFunction;
Expand All @@ -25,6 +26,11 @@ public class SubscribeOnlyOnceLifter<T> extends AtomicBoolean implements BiFunct
@Override
public CoreSubscriber<? super T> apply(Scannable scannable, CoreSubscriber<? super T> coreSubscriber) {
return new CoreSubscriber<T>() {
@Override
public Context currentContext() {
return coreSubscriber.currentContext();
}

@Override
public void onSubscribe(Subscription subscription) {
if (!compareAndSet(false, true)) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
/*
* Copyright (c) 2019, Salesforce.com, Inc.
* All rights reserved.
* Licensed under the BSD 3-Clause license.
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
*/

package com.salesforce.reactorgrpc;

import io.grpc.testing.GrpcServerRule;
import org.junit.*;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Hooks;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;

import java.util.stream.Collectors;

import static org.assertj.core.api.Assertions.assertThat;

public class ReactorContextPropagationTest {

@Rule
public GrpcServerRule serverRule = new GrpcServerRule();

private static class SimpleGreeter extends ReactorGreeterGrpc.GreeterImplBase {
@Override
public Mono<HelloResponse> sayHello(Mono<HelloRequest> request) {
return request.map(HelloRequest::getName)
.map(name -> HelloResponse.newBuilder().setMessage("Hello " + name).build());
}

@Override
public Mono<HelloResponse> sayHelloReqStream(Flux<HelloRequest> request) {
return request.transformDeferredContextual((f, ctx) -> f.map(HelloRequest::getName))
.collect(Collectors.joining("and"))
.map(names -> HelloResponse.newBuilder().setMessage("Hello " + names).build());
}

@Override
public Flux<HelloResponse> sayHelloRespStream(Mono<HelloRequest> request) {
return request.repeat(2)
.map(HelloRequest::getName)
.zipWith(Flux.just("Hello ", "Hi ", "Greetings "), String::join)
.map(greeting -> HelloResponse.newBuilder().setMessage(greeting).build());
}

@Override
public Flux<HelloResponse> sayHelloBothStream(Flux<HelloRequest> request) {
return request.map(HelloRequest::getName)
.map(name -> HelloResponse.newBuilder().setMessage("Hello " + name).build());
}
}

@BeforeClass
public static void beforeAll(){
Hooks.enableContextLossTracking();
Hooks.onOperatorDebug();
}

@AfterClass
public static void afterAll(){
Hooks.disableContextLossTracking();
Hooks.resetOnOperatorDebug();
}

@Before
public void setup() {
serverRule.getServiceRegistry().addService(new SimpleGreeter());
}

@Test
public void oneToOne() {
ReactorGreeterGrpc.ReactorGreeterStub stub = ReactorGreeterGrpc.newReactorStub(serverRule.getChannel());
Mono<HelloRequest> req = Mono.just(HelloRequest.newBuilder().setName("reactor").build());

Mono<HelloResponse> resp = req
.doOnEach(signal -> assertThat(signal.getContextView().getOrEmpty("name")).isNotEmpty())
.transform(stub::sayHello)
.doOnEach(signal -> assertThat(signal.getContextView().getOrEmpty("name")).isNotEmpty())
.contextWrite(ctx -> ctx.put("name", "context"));

StepVerifier.create(resp)
.expectNextCount(1)
.verifyComplete();
}

@Test
public void oneToMany() {
ReactorGreeterGrpc.ReactorGreeterStub stub = ReactorGreeterGrpc.newReactorStub(serverRule.getChannel());
Mono<HelloRequest> req = Mono.just(HelloRequest.newBuilder().setName("reactor").build());

Flux<HelloResponse> resp = req
.doOnEach(signal -> assertThat(signal.getContextView().getOrEmpty("name")).isNotEmpty())
.as(stub::sayHelloRespStream)
.doOnEach(signal -> assertThat(signal.getContextView().getOrEmpty("name")).isNotEmpty())
.contextWrite(ctx -> ctx.put("name", "context"));

StepVerifier.create(resp)
.expectNextCount(3)
.verifyComplete();
}

@Test
public void manyToOne() {
ReactorGreeterGrpc.ReactorGreeterStub stub = ReactorGreeterGrpc.newReactorStub(serverRule.getChannel());
Flux<HelloRequest> req = Mono.deferContextual(ctx -> Mono.just(HelloRequest.newBuilder().setName(ctx.get("name")).build())).repeat(2);

Mono<HelloResponse> resp = req
.doOnEach(signal -> assertThat(signal.getContextView().getOrEmpty("name")).isNotEmpty())
.as(stub::sayHelloReqStream)
.doOnEach(signal -> assertThat(signal.getContextView().getOrEmpty("name")).isNotEmpty())
.contextWrite(ctx -> ctx.put("name", "context"));

StepVerifier.create(resp)
.expectAccessibleContext()
.contains("name", "context")
.then()
.expectNextCount(1)
.verifyComplete();
}

@Test
public void manyToMany() {
ReactorGreeterGrpc.ReactorGreeterStub stub = ReactorGreeterGrpc.newReactorStub(serverRule.getChannel());
Flux<HelloRequest> req = Mono.just(HelloRequest.newBuilder().setName("reactor").build()).repeat(2).contextWrite(c -> c.put("name", "boom"));

Flux<HelloResponse> resp = req
.doOnEach(signal -> assertThat(signal.getContextView().getOrEmpty("name")).isNotEmpty())
.transform(stub::sayHelloBothStream)
.doOnEach(signal -> assertThat(signal.getContextView().getOrEmpty("name")).isNotEmpty())
.contextWrite(ctx -> ctx.put("name", "context"));

StepVerifier.create(resp)
.expectNextCount(3)
.verifyComplete();
}
}

0 comments on commit a11f716

Please sign in to comment.