Skip to content

Commit

Permalink
feat: take client request headers into account when deciding whether …
Browse files Browse the repository at this point in the history
…to reuse WS conns (#647)

Allow callers to indicate a list of headers (by name / regular
expression) that must be taken into account when deciding to reuse
connections.
  • Loading branch information
fiam authored Nov 6, 2023
1 parent 1950d2f commit c23f458
Show file tree
Hide file tree
Showing 13 changed files with 313 additions and 95 deletions.
55 changes: 44 additions & 11 deletions v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"fmt"
"io"
"net/http"
"regexp"

"github.com/buger/jsonparser"
"github.com/jensneuse/abstractlogger"
Expand Down Expand Up @@ -326,6 +327,16 @@ type SubscriptionConfiguration struct {
URL string
UseSSE bool
SSEMethodPost bool
// ForwardedClientHeaderNames indicates headers names that might be forwarded from the
// client to the upstream server. This is used to determine which connections
// can be multiplexed together, but the subscription engine does not forward
// these headers by itself.
ForwardedClientHeaderNames []string
// ForwardedClientHeaderRegularExpressions regular expressions that if matched to the header
// name might be forwarded from the client to the upstream server. This is used to determine
// which connections can be multiplexed together, but the subscription engine does not forward
// these headers by itself.
ForwardedClientHeaderRegularExpressions []*regexp.Regexp
}

type FetchConfiguration struct {
Expand Down Expand Up @@ -369,7 +380,7 @@ func (p *Planner) ConfigureFetch() resolve.FetchConfiguration {
input = httpclient.SetInputBodyWithPath(input, p.printOperation(), "query")

if p.unnulVariables {
input = httpclient.SetInputFlag(input, httpclient.UNNULLVARIABLES)
input = httpclient.SetInputFlag(input, httpclient.UNNULL_VARIABLES)
}

header, err := json.Marshal(p.config.Fetch.Header)
Expand Down Expand Up @@ -437,9 +448,9 @@ func (p *Planner) ConfigureSubscription() plan.SubscriptionConfiguration {
input = httpclient.SetInputBodyWithPath(input, p.printOperation(), "query")
input = httpclient.SetInputURL(input, []byte(p.config.Subscription.URL))
if p.config.Subscription.UseSSE {
input = httpclient.SetInputFlag(input, httpclient.USESSE)
input = httpclient.SetInputFlag(input, httpclient.USE_SSE)
if p.config.Subscription.SSEMethodPost {
input = httpclient.SetInputFlag(input, httpclient.SSEMETHODPOST)
input = httpclient.SetInputFlag(input, httpclient.SSE_METHOD_POST)
}
}

Expand All @@ -448,6 +459,26 @@ func (p *Planner) ConfigureSubscription() plan.SubscriptionConfiguration {
input = httpclient.SetInputHeader(input, header)
}

if len(p.config.Subscription.ForwardedClientHeaderNames) > 0 {
headers, err := json.Marshal(p.config.Subscription.ForwardedClientHeaderNames)
if err != nil {
// XXX: Since this is a very unlikely error, to avoid breaking
// the API we panic here
panic(err)
}
input = httpclient.SetForwardedClientHeaderNames(input, headers)
}

if len(p.config.Subscription.ForwardedClientHeaderRegularExpressions) > 0 {
headers, err := json.Marshal(p.config.Subscription.ForwardedClientHeaderRegularExpressions)
if err != nil {
// XXX: Since this is a very unlikely error, to avoid breaking
// the API we panic here
panic(err)
}
input = httpclient.SetForwardedClientHeaderRegularExpressions(input, headers)
}

return plan.SubscriptionConfiguration{
Input: string(input),
DataSource: &SubscriptionSource{
Expand Down Expand Up @@ -1651,7 +1682,7 @@ func (s *Source) compactAndUnNullVariables(input []byte) []byte {
variables = buf.Bytes()
}

removeNullVariables := httpclient.IsInputFlagSet(input, httpclient.UNNULLVARIABLES)
removeNullVariables := httpclient.IsInputFlagSet(input, httpclient.UNNULL_VARIABLES)
variables = s.cleanupVariables(variables, removeNullVariables, undefinedVariables)

input, _ = jsonparser.Set(input, variables, "body", "variables")
Expand Down Expand Up @@ -1722,15 +1753,17 @@ func (s *Source) Load(ctx context.Context, input []byte, writer io.Writer) (err
}

type GraphQLSubscriptionClient interface {
Subscribe(ctx context.Context, options GraphQLSubscriptionOptions, next chan<- []byte) error
Subscribe(ctx *resolve.Context, options GraphQLSubscriptionOptions, next chan<- []byte) error
}

type GraphQLSubscriptionOptions struct {
URL string `json:"url"`
Body GraphQLBody `json:"body"`
Header http.Header `json:"header"`
UseSSE bool `json:"use_sse"`
SSEMethodPost bool `json:"sse_method_post"`
URL string `json:"url"`
Body GraphQLBody `json:"body"`
Header http.Header `json:"header"`
UseSSE bool `json:"use_sse"`
SSEMethodPost bool `json:"sse_method_post"`
ForwardedClientHeaderNames []string `json:"forwarded_client_header_names"`
ForwardedClientHeaderRegularExpressions []*regexp.Regexp `json:"forwarded_client_header_regular_expressions"`
}

type GraphQLBody struct {
Expand All @@ -1744,7 +1777,7 @@ type SubscriptionSource struct {
client GraphQLSubscriptionClient
}

func (s *SubscriptionSource) Start(ctx context.Context, input []byte, next chan<- []byte) error {
func (s *SubscriptionSource) Start(ctx *resolve.Context, input []byte, next chan<- []byte) error {
var options GraphQLSubscriptionOptions
err := json.Unmarshal(input, &options)
if err != nil {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7996,7 +7996,7 @@ var errSubscriptionClientFail = errors.New("subscription client fail error")

type FailingSubscriptionClient struct{}

func (f FailingSubscriptionClient) Subscribe(_ context.Context, _ GraphQLSubscriptionOptions, _ chan<- []byte) error {
func (f FailingSubscriptionClient) Subscribe(_ *resolve.Context, _ GraphQLSubscriptionOptions, _ chan<- []byte) error {
return errSubscriptionClientFail
}

Expand Down Expand Up @@ -8043,13 +8043,13 @@ func TestSubscriptionSource_Start(t *testing.T) {

t.Run("should return error when input is invalid", func(t *testing.T) {
source := SubscriptionSource{client: FailingSubscriptionClient{}}
err := source.Start(context.Background(), []byte(`{"url": "", "body": "", "header": null}`), nil)
err := source.Start(resolve.NewContext(context.Background()), []byte(`{"url": "", "body": "", "header": null}`), nil)
assert.Error(t, err)
})

t.Run("should return error when subscription client returns an error", func(t *testing.T) {
source := SubscriptionSource{client: FailingSubscriptionClient{}}
err := source.Start(context.Background(), []byte(`{"url": "", "body": {}, "header": null}`), nil)
err := source.Start(resolve.NewContext(context.Background()), []byte(`{"url": "", "body": {}, "header": null}`), nil)
assert.Error(t, err)
assert.Equal(t, resolve.ErrUnableToResolve, err)
})
Expand All @@ -8061,7 +8061,7 @@ func TestSubscriptionSource_Start(t *testing.T) {

source := newSubscriptionSource(ctx.Context())
chatSubscriptionOptions := chatServerSubscriptionOptions(t, `{"variables": {}, "extensions": {}, "operationName": "LiveMessages", "query": "subscription LiveMessages { messageAdded(roomName: "#test") { text createdBy } }"}`)
err := source.Start(ctx.Context(), chatSubscriptionOptions, next)
err := source.Start(ctx, chatSubscriptionOptions, next)
require.ErrorIs(t, err, resolve.ErrUnableToResolve)
})

Expand All @@ -8072,7 +8072,7 @@ func TestSubscriptionSource_Start(t *testing.T) {

source := newSubscriptionSource(ctx.Context())
chatSubscriptionOptions := chatServerSubscriptionOptions(t, `{"variables": {}, "extensions": {}, "operationName": "LiveMessages", "query": "subscription LiveMessages { messageAdded(roomNam: \"#test\") { text createdBy } }"}`)
err := source.Start(ctx.Context(), chatSubscriptionOptions, next)
err := source.Start(ctx, chatSubscriptionOptions, next)
require.NoError(t, err)

msg, ok := <-next
Expand All @@ -8090,7 +8090,7 @@ func TestSubscriptionSource_Start(t *testing.T) {

source := newSubscriptionSource(resolverLifecycle)
chatSubscriptionOptions := chatServerSubscriptionOptions(t, `{"variables": {}, "extensions": {}, "operationName": "LiveMessages", "query": "subscription LiveMessages { messageAdded(roomName: \"#test\") { text createdBy } }"}`)
err := source.Start(subscriptionLifecycle, chatSubscriptionOptions, next)
err := source.Start(resolve.NewContext(subscriptionLifecycle), chatSubscriptionOptions, next)
require.NoError(t, err)

username := "myuser"
Expand All @@ -8111,7 +8111,7 @@ func TestSubscriptionSource_Start(t *testing.T) {

source := newSubscriptionSource(ctx.Context())
chatSubscriptionOptions := chatServerSubscriptionOptions(t, `{"variables": {}, "extensions": {}, "operationName": "LiveMessages", "query": "subscription LiveMessages { messageAdded(roomName: \"#test\") { text createdBy } }"}`)
err := source.Start(ctx.Context(), chatSubscriptionOptions, next)
err := source.Start(ctx, chatSubscriptionOptions, next)
require.NoError(t, err)

username := "myuser"
Expand Down Expand Up @@ -8173,7 +8173,7 @@ func TestSubscription_GTWS_SubProtocol(t *testing.T) {

source := newSubscriptionSource(ctx.Context())
chatSubscriptionOptions := chatServerSubscriptionOptions(t, `{"variables": {}, "extensions": {}, "operationName": "LiveMessages", "query": "subscription LiveMessages { messageAdded(roomNam: \"#test\") { text createdBy } }"}`)
err := source.Start(ctx.Context(), chatSubscriptionOptions, next)
err := source.Start(ctx, chatSubscriptionOptions, next)
require.NoError(t, err)

msg, ok := <-next
Expand All @@ -8191,7 +8191,7 @@ func TestSubscription_GTWS_SubProtocol(t *testing.T) {

source := newSubscriptionSource(resolverLifecycle)
chatSubscriptionOptions := chatServerSubscriptionOptions(t, `{"variables": {}, "extensions": {}, "operationName": "LiveMessages", "query": "subscription LiveMessages { messageAdded(roomName: \"#test\") { text createdBy } }"}`)
err := source.Start(subscriptionLifecycle, chatSubscriptionOptions, next)
err := source.Start(resolve.NewContext(subscriptionLifecycle), chatSubscriptionOptions, next)
require.NoError(t, err)

username := "myuser"
Expand All @@ -8212,7 +8212,7 @@ func TestSubscription_GTWS_SubProtocol(t *testing.T) {

source := newSubscriptionSource(ctx.Context())
chatSubscriptionOptions := chatServerSubscriptionOptions(t, `{"variables": {}, "extensions": {}, "operationName": "LiveMessages", "query": "subscription LiveMessages { messageAdded(roomName: \"#test\") { text createdBy } }"}`)
err := source.Start(ctx.Context(), chatSubscriptionOptions, next)
err := source.Start(ctx, chatSubscriptionOptions, next)
require.NoError(t, err)

username := "myuser"
Expand Down Expand Up @@ -8364,7 +8364,7 @@ func TestSource_Load(t *testing.T) {
var input []byte
input = httpclient.SetInputBodyWithPath(input, variables, "variables")
input = httpclient.SetInputURL(input, []byte(serverUrl))
input = httpclient.SetInputFlag(input, httpclient.UNNULLVARIABLES)
input = httpclient.SetInputFlag(input, httpclient.UNNULL_VARIABLES)
buf := bytes.NewBuffer(nil)

require.NoError(t, src.Load(context.Background(), input, buf))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/buger/jsonparser"
log "github.com/jensneuse/abstractlogger"
"github.com/r3labs/sse/v2"
"github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve"
)

var (
Expand All @@ -29,10 +30,10 @@ type gqlSSEConnectionHandler struct {
options GraphQLSubscriptionOptions
}

func newSSEConnectionHandler(ctx context.Context, conn *http.Client, opts GraphQLSubscriptionOptions, l log.Logger) *gqlSSEConnectionHandler {
func newSSEConnectionHandler(ctx *resolve.Context, conn *http.Client, opts GraphQLSubscriptionOptions, l log.Logger) *gqlSSEConnectionHandler {
return &gqlSSEConnectionHandler{
conn: conn,
ctx: ctx,
ctx: ctx.Context(),
log: l,
options: opts,
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve"
)

func TestGraphQLSubscriptionClientSubscribe_SSE(t *testing.T) {
Expand Down Expand Up @@ -51,7 +52,8 @@ func TestGraphQLSubscriptionClientSubscribe_SSE(t *testing.T) {
)

next := make(chan []byte)
err := client.Subscribe(ctx, GraphQLSubscriptionOptions{

err := client.Subscribe(resolve.NewContext(ctx), GraphQLSubscriptionOptions{
URL: server.URL,
Body: GraphQLBody{
Query: `subscription {messageAdded(roomName: "room"){text}}`,
Expand Down Expand Up @@ -89,7 +91,7 @@ func TestGraphQLSubscriptionClientSubscribe_SSE_RequestAbort(t *testing.T) {
)

next := make(chan []byte)
err := client.Subscribe(ctx, GraphQLSubscriptionOptions{
err := client.Subscribe(resolve.NewContext(ctx), GraphQLSubscriptionOptions{
URL: "http://dummy",
Body: GraphQLBody{
Query: `subscription {messageAdded(roomName: "room"){text}}`,
Expand Down Expand Up @@ -150,7 +152,7 @@ func TestGraphQLSubscriptionClientSubscribe_SSE_POST(t *testing.T) {
)

next := make(chan []byte)
err = client.Subscribe(ctx, GraphQLSubscriptionOptions{
err = client.Subscribe(resolve.NewContext(ctx), GraphQLSubscriptionOptions{
URL: server.URL,
Body: postReqBody,
UseSSE: true,
Expand Down Expand Up @@ -208,7 +210,7 @@ func TestGraphQLSubscriptionClientSubscribe_SSE_WithEvents(t *testing.T) {
)

next := make(chan []byte)
err := client.Subscribe(ctx, GraphQLSubscriptionOptions{
err := client.Subscribe(resolve.NewContext(ctx), GraphQLSubscriptionOptions{
URL: server.URL,
Body: GraphQLBody{
Query: `subscription {messageAdded(roomName: "room"){text}}`,
Expand Down Expand Up @@ -262,7 +264,7 @@ func TestGraphQLSubscriptionClientSubscribe_SSE_Error(t *testing.T) {
)

next := make(chan []byte)
err := client.Subscribe(ctx, GraphQLSubscriptionOptions{
err := client.Subscribe(resolve.NewContext(ctx), GraphQLSubscriptionOptions{
URL: server.URL,
Body: GraphQLBody{
Query: `subscription {messageAdded(roomName: "room"){text}}`,
Expand Down Expand Up @@ -314,7 +316,7 @@ func TestGraphQLSubscriptionClientSubscribe_SSE_Error_Without_Header(t *testing.
)

next := make(chan []byte)
err := client.Subscribe(ctx, GraphQLSubscriptionOptions{
err := client.Subscribe(resolve.NewContext(ctx), GraphQLSubscriptionOptions{
URL: server.URL,
Body: GraphQLBody{
Query: `subscription {messageAdded(roomName: "room"){text}}`,
Expand Down Expand Up @@ -371,7 +373,7 @@ func TestGraphQLSubscriptionClientSubscribe_QueryParams(t *testing.T) {
)

next := make(chan []byte)
err := client.Subscribe(ctx, GraphQLSubscriptionOptions{
err := client.Subscribe(resolve.NewContext(ctx), GraphQLSubscriptionOptions{
URL: server.URL,
Body: GraphQLBody{
Query: `subscription($a: Int!){countdown(from: $a)}`,
Expand Down Expand Up @@ -497,7 +499,7 @@ func TestGraphQLSubscriptionClientSubscribe_SSE_Upstream_Dies(t *testing.T) {
)

next := make(chan []byte)
err := client.Subscribe(ctx, GraphQLSubscriptionOptions{
err := client.Subscribe(resolve.NewContext(ctx), GraphQLSubscriptionOptions{
URL: server.URL,
Body: GraphQLBody{
Query: `subscription {messageAdded(roomName: "room"){text}}`,
Expand Down
Loading

0 comments on commit c23f458

Please sign in to comment.