From 61784bfd4998422866af27c35e836ec9dbbe7f66 Mon Sep 17 00:00:00 2001 From: spetrunin Date: Wed, 27 Sep 2023 12:43:08 +0300 Subject: [PATCH 1/5] backport subscriptions refactoring from v1 to v2(#605) ad5d5159 --- v2/pkg/subscription/constants.go | 7 + v2/pkg/subscription/context.go | 15 +- v2/pkg/subscription/context_test.go | 25 +- v2/pkg/subscription/engine.go | 215 +++++ v2/pkg/subscription/engine_mock_test.go | 77 ++ v2/pkg/subscription/engine_test.go | 512 ++++++++++ v2/pkg/subscription/executor.go | 24 + v2/pkg/subscription/executor_mock_test.go | 141 +++ v2/pkg/subscription/handler.go | 588 +++--------- v2/pkg/subscription/handler_mock_test.go | 98 ++ v2/pkg/subscription/handler_test.go | 878 ++++++------------ v2/pkg/subscription/init.go | 13 +- v2/pkg/subscription/legacy_handler.go | 494 ++++++++++ v2/pkg/subscription/legacy_handler_test.go | 670 +++++++++++++ v2/pkg/subscription/time_out.go | 38 + v2/pkg/subscription/time_out_test.go | 59 ++ v2/pkg/subscription/transport_client.go | 27 + .../transport_client_mock_test.go | 105 +++ v2/pkg/subscription/websocket/client.go | 187 ++++ v2/pkg/subscription/websocket/client_test.go | 461 +++++++++ .../websocket/engine_mock_test.go | 78 ++ v2/pkg/subscription/websocket/handler.go | 228 +++++ v2/pkg/subscription/websocket/handler_test.go | 344 +++++++ v2/pkg/subscription/websocket/init.go | 47 + .../protocol_graphql_transport_ws.go | 524 +++++++++++ .../protocol_graphql_transport_ws_test.go | 569 ++++++++++++ .../websocket/protocol_graphql_ws.go | 359 +++++++ .../websocket/protocol_graphql_ws_test.go | 411 ++++++++ 28 files changed, 6163 insertions(+), 1031 deletions(-) create mode 100644 v2/pkg/subscription/constants.go create mode 100644 v2/pkg/subscription/engine.go create mode 100644 v2/pkg/subscription/engine_mock_test.go create mode 100644 v2/pkg/subscription/engine_test.go create mode 100644 v2/pkg/subscription/executor.go create mode 100644 v2/pkg/subscription/executor_mock_test.go create mode 100644 v2/pkg/subscription/handler_mock_test.go create mode 100644 v2/pkg/subscription/legacy_handler.go create mode 100644 v2/pkg/subscription/legacy_handler_test.go create mode 100644 v2/pkg/subscription/time_out.go create mode 100644 v2/pkg/subscription/time_out_test.go create mode 100644 v2/pkg/subscription/transport_client.go create mode 100644 v2/pkg/subscription/transport_client_mock_test.go create mode 100644 v2/pkg/subscription/websocket/client.go create mode 100644 v2/pkg/subscription/websocket/client_test.go create mode 100644 v2/pkg/subscription/websocket/engine_mock_test.go create mode 100644 v2/pkg/subscription/websocket/handler.go create mode 100644 v2/pkg/subscription/websocket/handler_test.go create mode 100644 v2/pkg/subscription/websocket/init.go create mode 100644 v2/pkg/subscription/websocket/protocol_graphql_transport_ws.go create mode 100644 v2/pkg/subscription/websocket/protocol_graphql_transport_ws_test.go create mode 100644 v2/pkg/subscription/websocket/protocol_graphql_ws.go create mode 100644 v2/pkg/subscription/websocket/protocol_graphql_ws_test.go diff --git a/v2/pkg/subscription/constants.go b/v2/pkg/subscription/constants.go new file mode 100644 index 000000000..2ed5557bc --- /dev/null +++ b/v2/pkg/subscription/constants.go @@ -0,0 +1,7 @@ +package subscription + +const ( + DefaultKeepAliveInterval = "15s" + DefaultSubscriptionUpdateInterval = "1s" + DefaultReadErrorTimeOut = "5s" +) diff --git a/v2/pkg/subscription/context.go b/v2/pkg/subscription/context.go index b71a77b5c..13b9acbf7 100644 --- a/v2/pkg/subscription/context.go +++ b/v2/pkg/subscription/context.go @@ -2,10 +2,16 @@ package subscription import ( "context" + "errors" + "fmt" "net/http" "sync" ) +var ( + ErrSubscriberIDAlreadyExists = errors.New("subscriber id already exists") +) + type InitialHttpRequestContext struct { context.Context Request *http.Request @@ -23,15 +29,18 @@ type subscriptionCancellations struct { cancellations map[string]context.CancelFunc } -func (sc *subscriptionCancellations) AddWithParent(id string, parent context.Context) context.Context { - ctx, cancelFunc := context.WithCancel(parent) +func (sc *subscriptionCancellations) AddWithParent(id string, parent context.Context) (context.Context, error) { sc.mu.Lock() defer sc.mu.Unlock() if sc.cancellations == nil { sc.cancellations = make(map[string]context.CancelFunc) } + if _, ok := sc.cancellations[id]; ok { + return nil, fmt.Errorf("%w: %s", ErrSubscriberIDAlreadyExists, id) + } + ctx, cancelFunc := context.WithCancel(parent) sc.cancellations[id] = cancelFunc - return ctx + return ctx, nil } func (sc *subscriptionCancellations) Cancel(id string) (ok bool) { diff --git a/v2/pkg/subscription/context_test.go b/v2/pkg/subscription/context_test.go index 8858533b3..ebe72f89c 100644 --- a/v2/pkg/subscription/context_test.go +++ b/v2/pkg/subscription/context_test.go @@ -26,11 +26,13 @@ func TestNewInitialHttpRequestContext(t *testing.T) { func TestSubscriptionCancellations(t *testing.T) { cancellations := subscriptionCancellations{} var ctx context.Context + var err error t.Run("should add a cancellation func to map", func(t *testing.T) { require.Equal(t, 0, cancellations.Len()) - ctx = cancellations.AddWithParent("1", context.Background()) + ctx, err = cancellations.AddWithParent("1", context.Background()) + assert.Nil(t, err) assert.Equal(t, 1, cancellations.Len()) assert.NotNil(t, ctx) }) @@ -48,3 +50,24 @@ func TestSubscriptionCancellations(t *testing.T) { assert.Equal(t, 0, cancellations.Len()) }) } + +func TestSubscriptionIdsShouldBeUnique(t *testing.T) { + sc := subscriptionCancellations{} + var ctx context.Context + var err error + + ctx, err = sc.AddWithParent("1", context.Background()) + assert.Nil(t, err) + assert.Equal(t, 1, sc.Len()) + assert.NotNil(t, ctx) + + ctx, err = sc.AddWithParent("2", context.Background()) + assert.Nil(t, err) + assert.Equal(t, 2, sc.Len()) + assert.NotNil(t, ctx) + + ctx, err = sc.AddWithParent("2", context.Background()) + assert.NotNil(t, err) + assert.Equal(t, 2, sc.Len()) + assert.Nil(t, ctx) +} diff --git a/v2/pkg/subscription/engine.go b/v2/pkg/subscription/engine.go new file mode 100644 index 000000000..313bfef9f --- /dev/null +++ b/v2/pkg/subscription/engine.go @@ -0,0 +1,215 @@ +package subscription + +//go:generate mockgen -destination=engine_mock_test.go -package=subscription . Engine +//go:generate mockgen -destination=websocket/engine_mock_test.go -package=websocket . Engine + +import ( + "context" + "errors" + "fmt" + "sync" + "time" + + "github.com/jensneuse/abstractlogger" + + "github.com/wundergraph/graphql-go-tools/pkg/ast" + "github.com/wundergraph/graphql-go-tools/pkg/graphql" +) + +type errOnBeforeStartHookFailure struct { + wrappedErr error +} + +func (e *errOnBeforeStartHookFailure) Unwrap() error { + return e.wrappedErr +} + +func (e *errOnBeforeStartHookFailure) Error() string { + return fmt.Sprintf("on before start hook failed: %s", e.wrappedErr.Error()) +} + +// Engine defines the function for a subscription engine. +type Engine interface { + StartOperation(ctx context.Context, id string, payload []byte, eventHandler EventHandler) error + StopSubscription(id string, eventHandler EventHandler) error + TerminateAllSubscriptions(eventHandler EventHandler) error +} + +// ExecutorEngine is an implementation of Engine and works with subscription.Executor. +type ExecutorEngine struct { + logger abstractlogger.Logger + // subCancellations is map containing the cancellation functions to every active subscription. + subCancellations subscriptionCancellations + // executorPool is responsible to create and hold executors. + executorPool ExecutorPool + // bufferPool will hold buffers. + bufferPool *sync.Pool + // subscriptionUpdateInterval is the actual interval on which the server sends subscription updates to the client. + subscriptionUpdateInterval time.Duration +} + +// StartOperation will start any operation. +func (e *ExecutorEngine) StartOperation(ctx context.Context, id string, payload []byte, eventHandler EventHandler) error { + executor, err := e.executorPool.Get(payload) + if err != nil { + return err + } + + if err = e.handleOnBeforeStart(executor); err != nil { + eventHandler.Emit(EventTypeOnError, id, nil, err) + return &errOnBeforeStartHookFailure{wrappedErr: err} + } + + if ctx, err = e.checkForDuplicateSubscriberID(ctx, id, eventHandler); err != nil { + return err + } + + if executor.OperationType() == ast.OperationTypeSubscription { + go e.startSubscription(ctx, id, executor, eventHandler) + return nil + } + + go e.handleNonSubscriptionOperation(ctx, id, executor, eventHandler) + return nil +} + +// StopSubscription will stop an active subscription. +func (e *ExecutorEngine) StopSubscription(id string, eventHandler EventHandler) error { + e.subCancellations.Cancel(id) + eventHandler.Emit(EventTypeOnSubscriptionCompleted, id, nil, nil) + return nil +} + +// TerminateAllSubscriptions will cancel all active subscriptions. +func (e *ExecutorEngine) TerminateAllSubscriptions(eventHandler EventHandler) error { + if e.subCancellations.Len() == 0 { + return nil + } + + for id := range e.subCancellations.cancellations { + e.subCancellations.Cancel(id) + } + + eventHandler.Emit(EventTypeOnConnectionTerminatedByServer, "", []byte("connection terminated by server"), nil) + return nil +} + +func (e *ExecutorEngine) handleOnBeforeStart(executor Executor) error { + switch e := executor.(type) { + case *ExecutorV2: + if hook := e.engine.GetWebsocketBeforeStartHook(); hook != nil { + return hook.OnBeforeStart(e.reqCtx, e.operation) + } + case *ExecutorV1: + // do nothing + } + + return nil +} + +func (e *ExecutorEngine) checkForDuplicateSubscriberID(ctx context.Context, id string, eventHandler EventHandler) (context.Context, error) { + ctx, subsErr := e.subCancellations.AddWithParent(id, ctx) + if errors.Is(subsErr, ErrSubscriberIDAlreadyExists) { + eventHandler.Emit(EventTypeOnDuplicatedSubscriberID, id, nil, subsErr) + return ctx, subsErr + } else if subsErr != nil { + eventHandler.Emit(EventTypeOnError, id, nil, subsErr) + return ctx, subsErr + } + return ctx, nil +} + +func (e *ExecutorEngine) startSubscription(ctx context.Context, id string, executor Executor, eventHandler EventHandler) { + defer func() { + err := e.executorPool.Put(executor) + if err != nil { + e.logger.Error("subscription.Handle.startSubscription()", + abstractlogger.Error(err), + ) + } + }() + + executor.SetContext(ctx) + buf := e.bufferPool.Get().(*graphql.EngineResultWriter) + buf.Reset() + + defer e.bufferPool.Put(buf) + + e.executeSubscription(buf, id, executor, eventHandler) + + for { + buf.Reset() + select { + case <-ctx.Done(): + return + case <-time.After(e.subscriptionUpdateInterval): + e.executeSubscription(buf, id, executor, eventHandler) + } + } + +} + +func (e *ExecutorEngine) executeSubscription(buf *graphql.EngineResultWriter, id string, executor Executor, eventHandler EventHandler) { + buf.SetFlushCallback(func(data []byte) { + e.logger.Debug("subscription.Handle.executeSubscription()", + abstractlogger.ByteString("execution_result", data), + ) + eventHandler.Emit(EventTypeOnSubscriptionData, id, data, nil) + }) + defer buf.SetFlushCallback(nil) + + err := executor.Execute(buf) + if err != nil { + e.logger.Error("subscription.Handle.executeSubscription()", + abstractlogger.Error(err), + ) + + eventHandler.Emit(EventTypeOnError, id, nil, err) + return + } + + if buf.Len() > 0 { + data := buf.Bytes() + e.logger.Debug("subscription.Handle.executeSubscription()", + abstractlogger.ByteString("execution_result", data), + ) + eventHandler.Emit(EventTypeOnSubscriptionData, id, data, nil) + } +} + +func (e *ExecutorEngine) handleNonSubscriptionOperation(ctx context.Context, id string, executor Executor, eventHandler EventHandler) { + defer func() { + e.subCancellations.Cancel(id) + err := e.executorPool.Put(executor) + if err != nil { + e.logger.Error("subscription.Handle.handleNonSubscriptionOperation()", + abstractlogger.Error(err), + ) + } + }() + + executor.SetContext(ctx) + buf := e.bufferPool.Get().(*graphql.EngineResultWriter) + buf.Reset() + + defer e.bufferPool.Put(buf) + + err := executor.Execute(buf) + if err != nil { + e.logger.Error("subscription.Handle.handleNonSubscriptionOperation()", + abstractlogger.Error(err), + ) + + eventHandler.Emit(EventTypeOnError, id, nil, err) + return + } + + e.logger.Debug("subscription.Handle.handleNonSubscriptionOperation()", + abstractlogger.ByteString("execution_result", buf.Bytes()), + ) + + eventHandler.Emit(EventTypeOnNonSubscriptionExecutionResult, id, buf.Bytes(), err) +} + +// Interface Guards +var _ Engine = (*ExecutorEngine)(nil) diff --git a/v2/pkg/subscription/engine_mock_test.go b/v2/pkg/subscription/engine_mock_test.go new file mode 100644 index 000000000..fb5cd659b --- /dev/null +++ b/v2/pkg/subscription/engine_mock_test.go @@ -0,0 +1,77 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/wundergraph/graphql-go-tools/pkg/subscription (interfaces: Engine) + +// Package subscription is a generated GoMock package. +package subscription + +import ( + context "context" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockEngine is a mock of Engine interface. +type MockEngine struct { + ctrl *gomock.Controller + recorder *MockEngineMockRecorder +} + +// MockEngineMockRecorder is the mock recorder for MockEngine. +type MockEngineMockRecorder struct { + mock *MockEngine +} + +// NewMockEngine creates a new mock instance. +func NewMockEngine(ctrl *gomock.Controller) *MockEngine { + mock := &MockEngine{ctrl: ctrl} + mock.recorder = &MockEngineMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockEngine) EXPECT() *MockEngineMockRecorder { + return m.recorder +} + +// StartOperation mocks base method. +func (m *MockEngine) StartOperation(arg0 context.Context, arg1 string, arg2 []byte, arg3 EventHandler) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "StartOperation", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].(error) + return ret0 +} + +// StartOperation indicates an expected call of StartOperation. +func (mr *MockEngineMockRecorder) StartOperation(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartOperation", reflect.TypeOf((*MockEngine)(nil).StartOperation), arg0, arg1, arg2, arg3) +} + +// StopSubscription mocks base method. +func (m *MockEngine) StopSubscription(arg0 string, arg1 EventHandler) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "StopSubscription", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// StopSubscription indicates an expected call of StopSubscription. +func (mr *MockEngineMockRecorder) StopSubscription(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StopSubscription", reflect.TypeOf((*MockEngine)(nil).StopSubscription), arg0, arg1) +} + +// TerminateAllSubscriptions mocks base method. +func (m *MockEngine) TerminateAllSubscriptions(arg0 EventHandler) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "TerminateAllSubscriptions", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// TerminateAllSubscriptions indicates an expected call of TerminateAllSubscriptions. +func (mr *MockEngineMockRecorder) TerminateAllSubscriptions(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TerminateAllSubscriptions", reflect.TypeOf((*MockEngine)(nil).TerminateAllSubscriptions), arg0) +} diff --git a/v2/pkg/subscription/engine_test.go b/v2/pkg/subscription/engine_test.go new file mode 100644 index 000000000..5ad01ae34 --- /dev/null +++ b/v2/pkg/subscription/engine_test.go @@ -0,0 +1,512 @@ +package subscription + +import ( + "bytes" + "context" + "errors" + "runtime" + "sync" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/jensneuse/abstractlogger" + "github.com/stretchr/testify/assert" + + "github.com/wundergraph/graphql-go-tools/pkg/ast" + "github.com/wundergraph/graphql-go-tools/pkg/graphql" +) + +func TestExecutorEngine_StartOperation(t *testing.T) { + t.Run("execute non-subscription operation", func(t *testing.T) { + t.Run("on execution failure", func(t *testing.T) { + wg := &sync.WaitGroup{} + wg.Add(2) + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx, cancelFunc := context.WithTimeout(context.Background(), 25*time.Millisecond) + defer cancelFunc() + + idQuery := "1" + payloadQuery := []byte(`{"query":"{ hello }"}`) + + idMutation := "2" + payloadMutation := []byte(`{"query":"mutation { update }"}`) + + executorMock := NewMockExecutor(ctrl) + executorMock.EXPECT().OperationType(). + Return(ast.OperationTypeQuery). + Times(1) + executorMock.EXPECT().OperationType(). + Return(ast.OperationTypeMutation). + Times(1) + executorMock.EXPECT().SetContext(assignableToContextWithCancel(ctx)). + Times(2) + executorMock.EXPECT().Execute(gomock.AssignableToTypeOf(&graphql.EngineResultWriter{})). + Return(errors.New("error")). + Times(2) + + executorPoolMock := NewMockExecutorPool(ctrl) + executorPoolMock.EXPECT().Get(gomock.Eq(payloadQuery)). + Return(executorMock, nil). + Times(1) + executorPoolMock.EXPECT().Get(gomock.Eq(payloadMutation)). + Return(executorMock, nil). + Times(1) + executorPoolMock.EXPECT().Put(gomock.Eq(executorMock)). + Do(func(_ Executor) { + wg.Done() + }). + Times(2) + + eventHandlerMock := NewMockEventHandler(ctrl) + eventHandlerMock.EXPECT().Emit(gomock.Eq(EventTypeOnError), gomock.Eq(idQuery), gomock.Nil(), gomock.Any()). + Times(1) + eventHandlerMock.EXPECT().Emit(gomock.Eq(EventTypeOnError), gomock.Eq(idMutation), gomock.Nil(), gomock.Any()). + Times(1) + + engine := ExecutorEngine{ + logger: abstractlogger.Noop{}, + subCancellations: subscriptionCancellations{}, + executorPool: executorPoolMock, + bufferPool: &sync.Pool{ + New: func() interface{} { + writer := graphql.NewEngineResultWriterFromBuffer(bytes.NewBuffer(make([]byte, 0, 1024))) + return &writer + }, + }, + subscriptionUpdateInterval: 0, + } + + assert.Eventually(t, func() bool { + err := engine.StartOperation(ctx, idQuery, payloadQuery, eventHandlerMock) + assert.NoError(t, err) + + err = engine.StartOperation(ctx, idMutation, payloadMutation, eventHandlerMock) + assert.NoError(t, err) + + <-ctx.Done() + wg.Wait() + return true + }, 1*time.Second, 10*time.Millisecond) + }) + + t.Run("on execution success", func(t *testing.T) { + wg := &sync.WaitGroup{} + wg.Add(2) + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx, cancelFunc := context.WithTimeout(context.Background(), 25*time.Millisecond) + defer cancelFunc() + + idQuery := "1" + payloadQuery := []byte(`{"query":"{ hello }"}`) + + idMutation := "2" + payloadMutation := []byte(`{"query":"mutation { update }"}`) + + executorMock := NewMockExecutor(ctrl) + executorMock.EXPECT().OperationType(). + Return(ast.OperationTypeQuery). + Times(1) + executorMock.EXPECT().OperationType(). + Return(ast.OperationTypeMutation). + Times(1) + executorMock.EXPECT().SetContext(assignableToContextWithCancel(ctx)). + Times(2) + executorMock.EXPECT().Execute(gomock.AssignableToTypeOf(&graphql.EngineResultWriter{})). + Times(2) + + executorPoolMock := NewMockExecutorPool(ctrl) + executorPoolMock.EXPECT().Get(gomock.Eq(payloadQuery)). + Return(executorMock, nil). + Times(1) + executorPoolMock.EXPECT().Get(gomock.Eq(payloadMutation)). + Return(executorMock, nil). + Times(1) + executorPoolMock.EXPECT().Put(gomock.Eq(executorMock)). + Do(func(_ Executor) { + wg.Done() + }). + Times(2) + + eventHandlerMock := NewMockEventHandler(ctrl) + eventHandlerMock.EXPECT().Emit(gomock.Eq(EventTypeOnNonSubscriptionExecutionResult), gomock.Eq(idQuery), gomock.AssignableToTypeOf([]byte{}), gomock.Nil()). + Times(1) + eventHandlerMock.EXPECT().Emit(gomock.Eq(EventTypeOnNonSubscriptionExecutionResult), gomock.Eq(idMutation), gomock.AssignableToTypeOf([]byte{}), gomock.Nil()). + Times(1) + + engine := ExecutorEngine{ + logger: abstractlogger.Noop{}, + subCancellations: subscriptionCancellations{}, + executorPool: executorPoolMock, + bufferPool: &sync.Pool{ + New: func() interface{} { + writer := graphql.NewEngineResultWriterFromBuffer(bytes.NewBuffer(make([]byte, 0, 1024))) + return &writer + }, + }, + subscriptionUpdateInterval: 0, + } + + assert.Eventually(t, func() bool { + err := engine.StartOperation(ctx, idQuery, payloadQuery, eventHandlerMock) + assert.NoError(t, err) + + err = engine.StartOperation(ctx, idMutation, payloadMutation, eventHandlerMock) + assert.NoError(t, err) + + <-ctx.Done() + wg.Wait() + return true + }, 1*time.Second, 10*time.Millisecond) + }) + }) + + t.Run("execute subscription operation", func(t *testing.T) { + t.Run("on execution failure", func(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("this test fails on Windows due to different timings than unix, consider fixing it at some point") + } + wg := &sync.WaitGroup{} + wg.Add(1) + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx, cancelFunc := context.WithTimeout(context.Background(), 25*time.Millisecond) + defer cancelFunc() + + id := "1" + payload := []byte(`{"query":"subscription { receiveData }"}`) + + executorMock := NewMockExecutor(ctrl) + executorMock.EXPECT().OperationType(). + Return(ast.OperationTypeSubscription). + Times(1) + executorMock.EXPECT().SetContext(assignableToContextWithCancel(ctx)). + Times(1) + executorMock.EXPECT().Execute(gomock.AssignableToTypeOf(&graphql.EngineResultWriter{})). + Return(errors.New("error")). + MinTimes(2) + + executorPoolMock := NewMockExecutorPool(ctrl) + executorPoolMock.EXPECT().Get(gomock.Eq(payload)). + Return(executorMock, nil). + Times(1) + executorPoolMock.EXPECT().Put(gomock.Eq(executorMock)). + Do(func(_ Executor) { + wg.Done() + }). + Times(1) + + eventHandlerMock := NewMockEventHandler(ctrl) + eventHandlerMock.EXPECT().Emit(gomock.Eq(EventTypeOnError), gomock.Eq(id), gomock.Nil(), gomock.Any()). + MinTimes(2) + + engine := ExecutorEngine{ + logger: abstractlogger.Noop{}, + subCancellations: subscriptionCancellations{}, + executorPool: executorPoolMock, + bufferPool: &sync.Pool{ + New: func() interface{} { + writer := graphql.NewEngineResultWriterFromBuffer(bytes.NewBuffer(make([]byte, 0, 1024))) + return &writer + }, + }, + subscriptionUpdateInterval: 2 * time.Millisecond, + } + + assert.Eventually(t, func() bool { + err := engine.StartOperation(ctx, id, payload, eventHandlerMock) + <-ctx.Done() + wg.Wait() + return assert.NoError(t, err) + }, 1*time.Second, 10*time.Millisecond) + }) + + t.Run("on execution success", func(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("this test fails on Windows due to different timings than unix, consider fixing it at some point") + } + + wg := sync.WaitGroup{} + wg.Add(1) + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx, cancelFunc := context.WithTimeout(context.Background(), 25*time.Millisecond) + defer cancelFunc() + + id := "1" + payload := []byte(`{"query":"subscription { receiveData }"}`) + + executorMock := NewMockExecutor(ctrl) + executorMock.EXPECT().OperationType(). + Return(ast.OperationTypeSubscription). + Times(1) + executorMock.EXPECT().SetContext(assignableToContextWithCancel(ctx)). + Times(1) + executorMock.EXPECT().Execute(gomock.AssignableToTypeOf(&graphql.EngineResultWriter{})). + Do(func(resultWriter *graphql.EngineResultWriter) { + _, _ = resultWriter.Write([]byte(`{ "data": { "update": "newData" } }`)) + }). + MinTimes(2) + + executorPoolMock := NewMockExecutorPool(ctrl) + executorPoolMock.EXPECT().Get(gomock.Eq(payload)). + Return(executorMock, nil). + Times(1) + executorPoolMock.EXPECT().Put(gomock.Eq(executorMock)). + Do(func(_ Executor) { + wg.Done() + }). + Times(1) + + eventHandlerMock := NewMockEventHandler(ctrl) + eventHandlerMock.EXPECT().Emit(gomock.Eq(EventTypeOnSubscriptionData), gomock.Eq(id), gomock.AssignableToTypeOf([]byte{}), gomock.Nil()). + MinTimes(2) + + engine := ExecutorEngine{ + logger: abstractlogger.Noop{}, + subCancellations: subscriptionCancellations{}, + executorPool: executorPoolMock, + bufferPool: &sync.Pool{ + New: func() interface{} { + writer := graphql.NewEngineResultWriterFromBuffer(bytes.NewBuffer(make([]byte, 0, 1024))) + return &writer + }, + }, + subscriptionUpdateInterval: 2 * time.Millisecond, + } + + assert.Eventually(t, func() bool { + err := engine.StartOperation(ctx, id, payload, eventHandlerMock) + <-ctx.Done() + wg.Wait() + return assert.NoError(t, err) + }, 1*time.Second, 10*time.Millisecond) + }) + }) + + t.Run("error on duplicate id", func(t *testing.T) { + wg := &sync.WaitGroup{} + wg.Add(1) + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx, cancelFunc := context.WithTimeout(context.Background(), 25*time.Millisecond) + defer cancelFunc() + + id := "1" + payloadSubscription := []byte(`{"query":"subscription { receiveData }"}`) + payloadQuery := []byte(`{"query":"query { hello }"}`) + + executorMockQuery := NewMockExecutor(ctrl) + executorMockSubscription := NewMockExecutor(ctrl) + executorMockSubscription.EXPECT().OperationType(). + Return(ast.OperationTypeSubscription). + Times(1) + executorMockSubscription.EXPECT().SetContext(assignableToContextWithCancel(ctx)). + Times(1) + executorMockSubscription.EXPECT().Execute(gomock.AssignableToTypeOf(&graphql.EngineResultWriter{})). + Do(func(resultWriter *graphql.EngineResultWriter) { + _, _ = resultWriter.Write([]byte(`{ "data": { "receiveData": "newData" } }`)) + }). + Times(1) + + executorPoolMock := NewMockExecutorPool(ctrl) + executorPoolMock.EXPECT().Get(gomock.Eq(payloadSubscription)). + Return(executorMockSubscription, nil). + Times(1) + executorPoolMock.EXPECT().Get(gomock.Eq(payloadQuery)). + Return(executorMockQuery, nil). + Times(1) + executorPoolMock.EXPECT().Put(gomock.Eq(executorMockSubscription)). + Do(func(_ Executor) { + wg.Done() + }). + Times(1) + + eventHandlerMock := NewMockEventHandler(ctrl) + eventHandlerMock.EXPECT().Emit(gomock.Eq(EventTypeOnDuplicatedSubscriberID), gomock.Eq(id), gomock.Nil(), gomock.Any()). + Times(1) + eventHandlerMock.EXPECT().Emit(gomock.Eq(EventTypeOnSubscriptionData), gomock.Eq(id), gomock.AssignableToTypeOf([]byte{}), gomock.Nil()). + Times(1) + + engine := ExecutorEngine{ + logger: abstractlogger.Noop{}, + subCancellations: subscriptionCancellations{}, + executorPool: executorPoolMock, + bufferPool: &sync.Pool{ + New: func() interface{} { + writer := graphql.NewEngineResultWriterFromBuffer(bytes.NewBuffer(make([]byte, 0, 1024))) + return &writer + }, + }, + subscriptionUpdateInterval: 100 * time.Millisecond, + } + + assert.Eventually(t, func() bool { + err := engine.StartOperation(ctx, id, payloadSubscription, eventHandlerMock) + assert.NoError(t, err) + + err = engine.StartOperation(ctx, id, payloadQuery, eventHandlerMock) + assert.Error(t, err) + + <-ctx.Done() + wg.Wait() + return true + }, 1*time.Second, 10*time.Millisecond) + }) +} + +func TestExecutorEngine_StopSubscription(t *testing.T) { + wg := &sync.WaitGroup{} + wg.Add(1) + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx := context.Background() + + id := "1" + payload := []byte(`{"query":"subscription { receiveData }"}`) + + eventHandlerMock := NewMockEventHandler(ctrl) + eventHandlerMock.EXPECT().Emit(gomock.Eq(EventTypeOnSubscriptionCompleted), gomock.Eq(id), gomock.Nil(), gomock.Nil()). + Times(1) + eventHandlerMock.EXPECT().Emit(gomock.Eq(EventTypeOnSubscriptionData), gomock.Eq(id), gomock.AssignableToTypeOf([]byte{}), gomock.Nil()). + MinTimes(1) + + executorMock := NewMockExecutor(ctrl) + executorMock.EXPECT().OperationType(). + Return(ast.OperationTypeSubscription). + Times(1) + executorMock.EXPECT().SetContext(assignableToContextWithCancel(ctx)). + Times(1) + executorMock.EXPECT().Execute(gomock.AssignableToTypeOf(&graphql.EngineResultWriter{})). + Do(func(resultWriter *graphql.EngineResultWriter) { + _, _ = resultWriter.Write([]byte(`{ "data": { "receiveData": "newData" } }`)) + }). + MinTimes(1) + + executorPoolMock := NewMockExecutorPool(ctrl) + executorPoolMock.EXPECT().Get(gomock.Eq(payload)). + Return(executorMock, nil). + Times(1) + executorPoolMock.EXPECT().Put(gomock.Eq(executorMock)). + Do(func(_ Executor) { + wg.Done() + }). + Times(1) + + engine := ExecutorEngine{ + logger: abstractlogger.Noop{}, + subCancellations: subscriptionCancellations{}, + executorPool: executorPoolMock, + bufferPool: &sync.Pool{ + New: func() interface{} { + writer := graphql.NewEngineResultWriterFromBuffer(bytes.NewBuffer(make([]byte, 0, 1024))) + return &writer + }, + }, + subscriptionUpdateInterval: 2 * time.Millisecond, + } + + assert.Eventually(t, func() bool { + err := engine.StartOperation(ctx, id, payload, eventHandlerMock) + assert.NoError(t, err) + assert.Equal(t, 1, engine.subCancellations.Len()) + time.Sleep(5 * time.Millisecond) + + err = engine.StopSubscription(id, eventHandlerMock) + assert.NoError(t, err) + assert.Equal(t, 0, engine.subCancellations.Len()) + wg.Wait() + + return true + }, 1*time.Second, 5*time.Millisecond) +} + +func TestExecutorEngine_TerminateAllConnections(t *testing.T) { + wg := &sync.WaitGroup{} + wg.Add(3) + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx := context.Background() + + payload := []byte(`{"query":"subscription { receiveData }"}`) + + eventHandlerMock := NewMockEventHandler(ctrl) + eventHandlerMock.EXPECT().Emit(gomock.Eq(EventTypeOnConnectionTerminatedByServer), gomock.Eq(""), gomock.Eq([]byte("connection terminated by server")), gomock.Nil()). + Times(1) + eventHandlerMock.EXPECT().Emit(gomock.Eq(EventTypeOnSubscriptionData), gomock.Any(), gomock.AssignableToTypeOf([]byte{}), gomock.Nil()). + MinTimes(3) + + executorMock := NewMockExecutor(ctrl) + executorMock.EXPECT().OperationType(). + Return(ast.OperationTypeSubscription). + Times(3) + executorMock.EXPECT().SetContext(assignableToContextWithCancel(ctx)). + Times(3) + executorMock.EXPECT().Execute(gomock.AssignableToTypeOf(&graphql.EngineResultWriter{})). + Do(func(resultWriter *graphql.EngineResultWriter) { + _, _ = resultWriter.Write([]byte(`{ "data": { "receiveData": "newData" } }`)) + }). + MinTimes(3) + + executorPoolMock := NewMockExecutorPool(ctrl) + executorPoolMock.EXPECT().Get(gomock.Eq(payload)). + Return(executorMock, nil). + Times(3) + executorPoolMock.EXPECT().Put(gomock.Eq(executorMock)). + Do(func(_ Executor) { + wg.Done() + }). + Times(3) + + engine := ExecutorEngine{ + logger: abstractlogger.Noop{}, + subCancellations: subscriptionCancellations{}, + executorPool: executorPoolMock, + bufferPool: &sync.Pool{ + New: func() interface{} { + writer := graphql.NewEngineResultWriterFromBuffer(bytes.NewBuffer(make([]byte, 0, 1024))) + return &writer + }, + }, + subscriptionUpdateInterval: 2 * time.Millisecond, + } + + assert.Eventually(t, func() bool { + err := engine.StartOperation(ctx, "1", payload, eventHandlerMock) + assert.NoError(t, err) + err = engine.StartOperation(ctx, "2", payload, eventHandlerMock) + assert.NoError(t, err) + err = engine.StartOperation(ctx, "3", payload, eventHandlerMock) + assert.NoError(t, err) + assert.Equal(t, 3, engine.subCancellations.Len()) + time.Sleep(5 * time.Millisecond) + + err = engine.TerminateAllSubscriptions(eventHandlerMock) + assert.NoError(t, err) + assert.Equal(t, 0, engine.subCancellations.Len()) + wg.Wait() + + return true + }, 1*time.Second, 5*time.Millisecond) +} + +func assignableToContextWithCancel(ctx context.Context) gomock.Matcher { + ctxWithCancel, _ := context.WithCancel(ctx) //nolint:govet + return gomock.AssignableToTypeOf(ctxWithCancel) +} diff --git a/v2/pkg/subscription/executor.go b/v2/pkg/subscription/executor.go new file mode 100644 index 000000000..81c2ae2dc --- /dev/null +++ b/v2/pkg/subscription/executor.go @@ -0,0 +1,24 @@ +package subscription + +//go:generate mockgen -destination=executor_mock_test.go -package=subscription . Executor,ExecutorPool + +import ( + "context" + + "github.com/wundergraph/graphql-go-tools/pkg/ast" + "github.com/wundergraph/graphql-go-tools/pkg/engine/resolve" +) + +// Executor is an abstraction for executing a GraphQL engine +type Executor interface { + Execute(writer resolve.FlushWriter) error + OperationType() ast.OperationType + SetContext(context context.Context) + Reset() +} + +// ExecutorPool is an abstraction for creating executors +type ExecutorPool interface { + Get(payload []byte) (Executor, error) + Put(executor Executor) error +} diff --git a/v2/pkg/subscription/executor_mock_test.go b/v2/pkg/subscription/executor_mock_test.go new file mode 100644 index 000000000..6d7716678 --- /dev/null +++ b/v2/pkg/subscription/executor_mock_test.go @@ -0,0 +1,141 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/wundergraph/graphql-go-tools/pkg/subscription (interfaces: Executor,ExecutorPool) + +// Package subscription is a generated GoMock package. +package subscription + +import ( + context "context" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + ast "github.com/wundergraph/graphql-go-tools/pkg/ast" + resolve "github.com/wundergraph/graphql-go-tools/pkg/engine/resolve" +) + +// MockExecutor is a mock of Executor interface. +type MockExecutor struct { + ctrl *gomock.Controller + recorder *MockExecutorMockRecorder +} + +// MockExecutorMockRecorder is the mock recorder for MockExecutor. +type MockExecutorMockRecorder struct { + mock *MockExecutor +} + +// NewMockExecutor creates a new mock instance. +func NewMockExecutor(ctrl *gomock.Controller) *MockExecutor { + mock := &MockExecutor{ctrl: ctrl} + mock.recorder = &MockExecutorMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockExecutor) EXPECT() *MockExecutorMockRecorder { + return m.recorder +} + +// Execute mocks base method. +func (m *MockExecutor) Execute(arg0 resolve.FlushWriter) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Execute", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Execute indicates an expected call of Execute. +func (mr *MockExecutorMockRecorder) Execute(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Execute", reflect.TypeOf((*MockExecutor)(nil).Execute), arg0) +} + +// OperationType mocks base method. +func (m *MockExecutor) OperationType() ast.OperationType { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "OperationType") + ret0, _ := ret[0].(ast.OperationType) + return ret0 +} + +// OperationType indicates an expected call of OperationType. +func (mr *MockExecutorMockRecorder) OperationType() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OperationType", reflect.TypeOf((*MockExecutor)(nil).OperationType)) +} + +// Reset mocks base method. +func (m *MockExecutor) Reset() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Reset") +} + +// Reset indicates an expected call of Reset. +func (mr *MockExecutorMockRecorder) Reset() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Reset", reflect.TypeOf((*MockExecutor)(nil).Reset)) +} + +// SetContext mocks base method. +func (m *MockExecutor) SetContext(arg0 context.Context) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetContext", arg0) +} + +// SetContext indicates an expected call of SetContext. +func (mr *MockExecutorMockRecorder) SetContext(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetContext", reflect.TypeOf((*MockExecutor)(nil).SetContext), arg0) +} + +// MockExecutorPool is a mock of ExecutorPool interface. +type MockExecutorPool struct { + ctrl *gomock.Controller + recorder *MockExecutorPoolMockRecorder +} + +// MockExecutorPoolMockRecorder is the mock recorder for MockExecutorPool. +type MockExecutorPoolMockRecorder struct { + mock *MockExecutorPool +} + +// NewMockExecutorPool creates a new mock instance. +func NewMockExecutorPool(ctrl *gomock.Controller) *MockExecutorPool { + mock := &MockExecutorPool{ctrl: ctrl} + mock.recorder = &MockExecutorPoolMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockExecutorPool) EXPECT() *MockExecutorPoolMockRecorder { + return m.recorder +} + +// Get mocks base method. +func (m *MockExecutorPool) Get(arg0 []byte) (Executor, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get", arg0) + ret0, _ := ret[0].(Executor) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Get indicates an expected call of Get. +func (mr *MockExecutorPoolMockRecorder) Get(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockExecutorPool)(nil).Get), arg0) +} + +// Put mocks base method. +func (m *MockExecutorPool) Put(arg0 Executor) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Put", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Put indicates an expected call of Put. +func (mr *MockExecutorPoolMockRecorder) Put(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Put", reflect.TypeOf((*MockExecutorPool)(nil).Put), arg0) +} diff --git a/v2/pkg/subscription/handler.go b/v2/pkg/subscription/handler.go index e9c15c244..38d8317c0 100644 --- a/v2/pkg/subscription/handler.go +++ b/v2/pkg/subscription/handler.go @@ -1,508 +1,210 @@ package subscription +//go:generate mockgen -destination=handler_mock_test.go -package=subscription . Protocol,EventHandler + import ( "bytes" "context" - "encoding/json" + "errors" "sync" "time" "github.com/jensneuse/abstractlogger" - "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" "github.com/wundergraph/graphql-go-tools/v2/pkg/graphql" ) -const ( - MessageTypeConnectionInit = "connection_init" - MessageTypeConnectionAck = "connection_ack" - MessageTypeConnectionError = "connection_error" - MessageTypeConnectionTerminate = "connection_terminate" - MessageTypeConnectionKeepAlive = "ka" - MessageTypeStart = "start" - MessageTypeStop = "stop" - MessageTypeData = "data" - MessageTypeError = "error" - MessageTypeComplete = "complete" +var ErrCouldNotReadMessageFromClient = errors.New("could not read message from client") + +// EventType can be used to define subscription events decoupled from any protocols. +type EventType int - DefaultKeepAliveInterval = "15s" - DefaultSubscriptionUpdateInterval = "1s" +const ( + EventTypeOnError EventType = iota + EventTypeOnSubscriptionData + EventTypeOnSubscriptionCompleted + EventTypeOnNonSubscriptionExecutionResult + EventTypeOnConnectionTerminatedByClient + EventTypeOnConnectionTerminatedByServer + EventTypeOnConnectionError + EventTypeOnConnectionOpened + EventTypeOnDuplicatedSubscriberID ) -// Message defines the actual subscription message wich will be passed from client to server and vice versa. -type Message struct { - Id string `json:"id"` - Type string `json:"type"` - Payload json.RawMessage `json:"payload"` +// Protocol defines an interface for a subscription protocol decoupled from the underlying transport. +type Protocol interface { + Handle(ctx context.Context, engine Engine, message []byte) error + EventHandler() EventHandler } -// client provides an interface which can be implemented by any possible subscription client like websockets, mqtt, etc. -type Client interface { - // ReadFromClient will invoke a read operation from the client connection. - ReadFromClient() (*Message, error) - // WriteToClient will invoke a write operation to the client connection. - WriteToClient(Message) error - // IsConnected will indicate if a connection is still established. - IsConnected() bool - // Disconnect will close the connection between server and client. - Disconnect() error +// EventHandler is an interface that handles subscription events. +type EventHandler interface { + Emit(eventType EventType, id string, data []byte, err error) } -// ExecutorPool is an abstraction for creating executors -type ExecutorPool interface { - Get(payload []byte) (Executor, error) - Put(executor Executor) error +// UniversalProtocolHandlerOptions is struct that defines options for the UniversalProtocolHandler. +type UniversalProtocolHandlerOptions struct { + Logger abstractlogger.Logger + CustomSubscriptionUpdateInterval time.Duration + CustomReadErrorTimeOut time.Duration + CustomEngine Engine } -// Executor is an abstraction for executing a GraphQL engine -type Executor interface { - Execute(writer resolve.FlushWriter) error - OperationType() ast.OperationType - SetContext(context context.Context) - Reset() +// UniversalProtocolHandler can handle any protocol by using the Protocol interface. +type UniversalProtocolHandler struct { + logger abstractlogger.Logger + client TransportClient + protocol Protocol + engine Engine + readErrorTimeOut time.Duration + isReadTimeOutTimerRunning bool + readTimeOutCancel context.CancelFunc } -// WebsocketInitFunc is called when the server receives connection init message from the client. -// This can be used to check initial payload to see whether to accept the websocket connection. -type WebsocketInitFunc func(ctx context.Context, initPayload InitPayload) (context.Context, error) +// NewUniversalProtocolHandler creates a new UniversalProtocolHandler. +func NewUniversalProtocolHandler(client TransportClient, protocol Protocol, executorPool ExecutorPool) (*UniversalProtocolHandler, error) { + options := UniversalProtocolHandlerOptions{ + Logger: abstractlogger.Noop{}, + } -// Handler is the actual subscription handler which will keep track on how to handle messages coming from the client. -type Handler struct { - logger abstractlogger.Logger - // client will hold the subscription client implementation. - client Client - // keepAliveInterval is the actual interval on which the server send keep alive messages to the client. - keepAliveInterval time.Duration - // subscriptionUpdateInterval is the actual interval on which the server sends subscription updates to the client. - subscriptionUpdateInterval time.Duration - // subCancellations stores a map containing the cancellation functions to every active subscription. - subCancellations subscriptionCancellations - // executorPool is responsible to create and hold executors. - executorPool ExecutorPool - // bufferPool will hold buffers. - bufferPool *sync.Pool - // initFunc will check initial payload to see whether to accept the websocket connection. - initFunc WebsocketInitFunc + return NewUniversalProtocolHandlerWithOptions(client, protocol, executorPool, options) } -func NewHandlerWithInitFunc( - logger abstractlogger.Logger, - client Client, - executorPool ExecutorPool, - initFunc WebsocketInitFunc, -) (*Handler, error) { - keepAliveInterval, err := time.ParseDuration(DefaultKeepAliveInterval) - if err != nil { - return nil, err +// NewUniversalProtocolHandlerWithOptions creates a new UniversalProtocolHandler. It requires an option struct. +func NewUniversalProtocolHandlerWithOptions(client TransportClient, protocol Protocol, executorPool ExecutorPool, options UniversalProtocolHandlerOptions) (*UniversalProtocolHandler, error) { + handler := UniversalProtocolHandler{ + logger: abstractlogger.Noop{}, + client: client, + protocol: protocol, } - subscriptionUpdateInterval, err := time.ParseDuration(DefaultSubscriptionUpdateInterval) - if err != nil { - return nil, err + if options.Logger != nil { + handler.logger = options.Logger } - return &Handler{ - logger: logger, - client: client, - keepAliveInterval: keepAliveInterval, - subscriptionUpdateInterval: subscriptionUpdateInterval, - subCancellations: subscriptionCancellations{}, - executorPool: executorPool, - bufferPool: &sync.Pool{ - New: func() interface{} { - writer := graphql.NewEngineResultWriterFromBuffer(bytes.NewBuffer(make([]byte, 0, 1024))) - return &writer - }, - }, - initFunc: initFunc, - }, nil -} - -// NewHandler creates a new subscription handler. -func NewHandler(logger abstractlogger.Logger, client Client, executorPool ExecutorPool) (*Handler, error) { - return NewHandlerWithInitFunc(logger, client, executorPool, nil) -} - -// Handle will handle the subscription connection. -func (h *Handler) Handle(ctx context.Context) { - defer h.subCancellations.CancelAll() - - for { - if !h.client.IsConnected() { - h.logger.Debug("subscription.Handler.Handle()", - abstractlogger.String("message", "client has disconnected"), - ) - - return - } - - message, err := h.client.ReadFromClient() + if options.CustomReadErrorTimeOut != 0 { + handler.readErrorTimeOut = options.CustomReadErrorTimeOut + } else { + parsedReadErrorTimeOut, err := time.ParseDuration(DefaultReadErrorTimeOut) if err != nil { - h.logger.Error("subscription.Handler.Handle()", - abstractlogger.Error(err), - abstractlogger.Any("message", message), - ) - - h.handleConnectionError("could not read message from client") - } else if message != nil { - switch message.Type { - case MessageTypeConnectionInit: - ctx, err = h.handleInit(ctx, message.Payload) - if err != nil { - h.terminateConnection("failed to accept the websocket connection") - return - } - - go h.handleKeepAlive(ctx) - case MessageTypeStart: - h.handleStart(ctx, message.Id, message.Payload) - case MessageTypeStop: - h.handleStop(message.Id) - case MessageTypeConnectionTerminate: - h.handleConnectionTerminate() - return - } - } - - select { - case <-ctx.Done(): - return - default: - continue + return nil, err } + handler.readErrorTimeOut = parsedReadErrorTimeOut } -} - -// ChangeKeepAliveInterval can be used to change the keep alive interval. -func (h *Handler) ChangeKeepAliveInterval(d time.Duration) { - h.keepAliveInterval = d -} - -// ChangeSubscriptionUpdateInterval can be used to change the update interval. -func (h *Handler) ChangeSubscriptionUpdateInterval(d time.Duration) { - h.subscriptionUpdateInterval = d -} -// handleInit will handle an init message. -func (h *Handler) handleInit(ctx context.Context, payload []byte) (extendedCtx context.Context, err error) { - if h.initFunc != nil { - var initPayload InitPayload - // decode initial payload - if len(payload) > 0 { - initPayload = payload - } - // check initial payload to see whether to accept the websocket connection - if extendedCtx, err = h.initFunc(ctx, initPayload); err != nil { - return extendedCtx, err - } + if options.CustomEngine != nil { + handler.engine = options.CustomEngine } else { - extendedCtx = ctx - } - - ackMessage := Message{ - Type: MessageTypeConnectionAck, - } - - if err = h.client.WriteToClient(ackMessage); err != nil { - return extendedCtx, err - } - - return extendedCtx, nil -} - -// handleStart will handle s start message. -func (h *Handler) handleStart(ctx context.Context, id string, payload []byte) { - executor, err := h.executorPool.Get(payload) - if err != nil { - h.logger.Error("subscription.Handler.handleStart()", - abstractlogger.Error(err), - ) - - h.handleError(id, graphql.RequestErrorsFromError(err)) - return - } - - if err = h.handleOnBeforeStart(executor); err != nil { - h.handleError(id, graphql.RequestErrorsFromError(err)) - return - } - - if executor.OperationType() == ast.OperationTypeSubscription { - ctx := h.subCancellations.AddWithParent(id, ctx) - go h.startSubscription(ctx, id, executor) - return - } - - go h.handleNonSubscriptionOperation(ctx, id, executor) -} - -func (h *Handler) handleOnBeforeStart(executor Executor) error { - switch e := executor.(type) { - case *ExecutorV2: - if hook := e.engine.GetWebsocketBeforeStartHook(); hook != nil { - return hook.OnBeforeStart(e.reqCtx, e.operation) + engine := ExecutorEngine{ + logger: handler.logger, + subCancellations: subscriptionCancellations{}, + executorPool: executorPool, + bufferPool: &sync.Pool{ + New: func() interface{} { + writer := graphql.NewEngineResultWriterFromBuffer(bytes.NewBuffer(make([]byte, 0, 1024))) + return &writer + }, + }, } - } - return nil -} - -// handleNonSubscriptionOperation will handle a non-subscription operation like a query or a mutation. -func (h *Handler) handleNonSubscriptionOperation(ctx context.Context, id string, executor Executor) { - defer func() { - err := h.executorPool.Put(executor) - if err != nil { - h.logger.Error("subscription.Handle.handleNonSubscriptionOperation()", - abstractlogger.Error(err), - ) + if options.CustomSubscriptionUpdateInterval != 0 { + engine.subscriptionUpdateInterval = options.CustomSubscriptionUpdateInterval + } else { + subscriptionUpdateInterval, err := time.ParseDuration(DefaultSubscriptionUpdateInterval) + if err != nil { + return nil, err + } + engine.subscriptionUpdateInterval = subscriptionUpdateInterval } - }() - - executor.SetContext(ctx) - buf := h.bufferPool.Get().(*graphql.EngineResultWriter) - buf.Reset() - - defer h.bufferPool.Put(buf) - - // err := executor.Execute(executionContext, node, buf) - err := executor.Execute(buf) - if err != nil { - h.logger.Error("subscription.Handle.handleNonSubscriptionOperation()", - abstractlogger.Error(err), - ) - - h.handleError(id, graphql.RequestErrorsFromError(err)) - return + handler.engine = &engine } - h.logger.Debug("subscription.Handle.handleNonSubscriptionOperation()", - abstractlogger.ByteString("execution_result", buf.Bytes()), - ) - - h.sendData(id, buf.Bytes()) - h.sendComplete(id) + return &handler, nil } -// startSubscription will invoke the actual subscription. -func (h *Handler) startSubscription(ctx context.Context, id string, executor Executor) { +// Handle will handle the subscription logic and forward messages to the actual protocol handler. +func (u *UniversalProtocolHandler) Handle(ctx context.Context) { + ctxWithCancel, cancel := context.WithCancel(ctx) defer func() { - err := h.executorPool.Put(executor) + err := u.engine.TerminateAllSubscriptions(u.protocol.EventHandler()) if err != nil { - h.logger.Error("subscription.Handle.startSubscription()", + u.logger.Error("subscription.UniversalProtocolHandler.Handle: on terminate connections", abstractlogger.Error(err), ) } + cancel() }() - executor.SetContext(ctx) - buf := h.bufferPool.Get().(*graphql.EngineResultWriter) - buf.Reset() - - defer h.bufferPool.Put(buf) - - h.executeSubscription(buf, id, executor) + u.protocol.EventHandler().Emit(EventTypeOnConnectionOpened, "", nil, nil) for { - buf.Reset() - select { - case <-ctx.Done(): - return - case <-time.After(h.subscriptionUpdateInterval): - h.executeSubscription(buf, id, executor) - } - } - -} - -// executeSubscription will keep execution the subscription until it ends. -func (h *Handler) executeSubscription(buf *graphql.EngineResultWriter, id string, executor Executor) { - buf.SetFlushCallback(func(data []byte) { - h.logger.Debug("subscription.Handle.executeSubscription()", - abstractlogger.ByteString("execution_result", data), - ) - h.sendData(id, data) - }) - defer buf.SetFlushCallback(nil) - - err := executor.Execute(buf) - if err != nil { - h.logger.Error("subscription.Handle.executeSubscription()", - abstractlogger.Error(err), - ) - - h.handleError(id, graphql.RequestErrorsFromError(err)) - return - } - - if buf.Len() > 0 { - data := buf.Bytes() - h.logger.Debug("subscription.Handle.executeSubscription()", - abstractlogger.ByteString("execution_result", data), - ) - h.sendData(id, data) - } -} - -// handleStop will handle a stop message, -func (h *Handler) handleStop(id string) { - h.subCancellations.Cancel(id) - h.sendComplete(id) -} - -// sendData will send a data message to the client. -func (h *Handler) sendData(id string, responseData []byte) { - dataMessage := Message{ - Id: id, - Type: MessageTypeData, - Payload: responseData, - } - - err := h.client.WriteToClient(dataMessage) - if err != nil { - h.logger.Error("subscription.Handler.sendData()", - abstractlogger.Error(err), - ) - } -} - -// nolint -// sendComplete will send a complete message to the client. -func (h *Handler) sendComplete(id string) { - completeMessage := Message{ - Id: id, - Type: MessageTypeComplete, - Payload: nil, - } - - err := h.client.WriteToClient(completeMessage) - if err != nil { - h.logger.Error("subscription.Handler.sendComplete()", - abstractlogger.Error(err), - ) - } -} - -// handleConnectionTerminate will handle a comnnection terminate message. -func (h *Handler) handleConnectionTerminate() { - err := h.client.Disconnect() - if err != nil { - h.logger.Error("subscription.Handler.handleConnectionTerminate()", - abstractlogger.Error(err), - ) - } -} + if !u.client.IsConnected() { + u.logger.Debug("subscription.UniversalProtocolHandler.Handle: on client is connected check", + abstractlogger.String("message", "client has disconnected"), + ) -// handleKeepAlive will handle the keep alive loop. -func (h *Handler) handleKeepAlive(ctx context.Context) { - for { - select { - case <-ctx.Done(): return - case <-time.After(h.keepAliveInterval): - h.sendKeepAlive() } - } -} -// sendKeepAlive will send a keep alive message to the client. -func (h *Handler) sendKeepAlive() { - keepAliveMessage := Message{ - Type: MessageTypeConnectionKeepAlive, - } - - err := h.client.WriteToClient(keepAliveMessage) - if err != nil { - h.logger.Error("subscription.Handler.sendKeepAlive()", - abstractlogger.Error(err), - ) - } -} - -func (h *Handler) terminateConnection(reason interface{}) { - payloadBytes, err := json.Marshal(reason) - if err != nil { - h.logger.Error("subscription.Handler.terminateConnection()", - abstractlogger.Error(err), - abstractlogger.Any("errorPayload", reason), - ) - } - - connectionErrorMessage := Message{ - Type: MessageTypeConnectionTerminate, - Payload: payloadBytes, - } - - err = h.client.WriteToClient(connectionErrorMessage) - if err != nil { - h.logger.Error("subscription.Handler.terminateConnection()", - abstractlogger.Error(err), - ) - - err := h.client.Disconnect() - if err != nil { - h.logger.Error("subscription.Handler.terminateConnection()", + message, err := u.client.ReadBytesFromClient() + if errors.Is(err, ErrTransportClientClosedConnection) { + u.logger.Debug("subscription.UniversalProtocolHandler.Handle: reading from a closed connection") + return + } else if err != nil { + u.logger.Error("subscription.UniversalProtocolHandler.Handle: on reading bytes from client", abstractlogger.Error(err), + abstractlogger.ByteString("message", message), ) - } - } -} -// handleConnectionError will handle a connection error message. -func (h *Handler) handleConnectionError(errorPayload interface{}) { - payloadBytes, err := json.Marshal(errorPayload) - if err != nil { - h.logger.Error("subscription.Handler.handleConnectionError()", - abstractlogger.Error(err), - abstractlogger.Any("errorPayload", errorPayload), - ) - } - - connectionErrorMessage := Message{ - Type: MessageTypeConnectionError, - Payload: payloadBytes, - } + if !u.isReadTimeOutTimerRunning { + var timeOutCtx context.Context + timeOutCtx, u.readTimeOutCancel = context.WithCancel(context.Background()) + params := TimeOutParams{ + Name: "subscription reader error time out", + Logger: u.logger, + TimeOutContext: timeOutCtx, + TimeOutAction: func() { + cancel() // stop the handler if timer runs out + }, + TimeOutDuration: u.readErrorTimeOut, + } + go TimeOutChecker(params) + u.isReadTimeOutTimerRunning = true + } - err = h.client.WriteToClient(connectionErrorMessage) - if err != nil { - h.logger.Error("subscription.Handler.handleConnectionError()", - abstractlogger.Error(err), - ) + u.protocol.EventHandler().Emit(EventTypeOnConnectionError, "", nil, ErrCouldNotReadMessageFromClient) + } else { + if u.isReadTimeOutTimerRunning && u.readTimeOutCancel != nil { + u.readTimeOutCancel() + u.isReadTimeOutTimerRunning = false + u.readTimeOutCancel = nil + } - err := h.client.Disconnect() - if err != nil { - h.logger.Error("subscription.Handler.handleError()", - abstractlogger.Error(err), - ) + if len(message) > 0 { + err := u.protocol.Handle(ctxWithCancel, u.engine, message) + if err != nil { + var onBeforeStartHookError *errOnBeforeStartHookFailure + if errors.As(err, &onBeforeStartHookError) { + // if we do have an errOnBeforeStartHookFailure than the error is expected and should be + // logged as 'Debug'. + u.logger.Debug("subscription.UniversalProtocolHandler.Handle: on protocol handling message", + abstractlogger.Error(err), + ) + } else { + // all other errors should be treated as unexpected and therefore being logged as 'Error'. + u.logger.Error("subscription.UniversalProtocolHandler.Handle: on protocol handling message", + abstractlogger.Error(err), + ) + } + } + } } - } -} - -// handleError will handle an error message. -func (h *Handler) handleError(id string, errors graphql.RequestErrors) { - payloadBytes, err := json.Marshal(errors) - if err != nil { - h.logger.Error("subscription.Handler.handleError()", - abstractlogger.Error(err), - abstractlogger.Any("errors", errors), - ) - } - errorMessage := Message{ - Id: id, - Type: MessageTypeError, - Payload: payloadBytes, - } - - err = h.client.WriteToClient(errorMessage) - if err != nil { - h.logger.Error("subscription.Handler.handleError()", - abstractlogger.Error(err), - ) + select { + case <-ctxWithCancel.Done(): + return + default: + continue + } } } - -// ActiveSubscriptions will return the actual number of active subscriptions for that client. -func (h *Handler) ActiveSubscriptions() int { - return h.subCancellations.Len() -} diff --git a/v2/pkg/subscription/handler_mock_test.go b/v2/pkg/subscription/handler_mock_test.go new file mode 100644 index 000000000..f9124b83c --- /dev/null +++ b/v2/pkg/subscription/handler_mock_test.go @@ -0,0 +1,98 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/wundergraph/graphql-go-tools/pkg/subscription (interfaces: Protocol,EventHandler) + +// Package subscription is a generated GoMock package. +package subscription + +import ( + context "context" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockProtocol is a mock of Protocol interface. +type MockProtocol struct { + ctrl *gomock.Controller + recorder *MockProtocolMockRecorder +} + +// MockProtocolMockRecorder is the mock recorder for MockProtocol. +type MockProtocolMockRecorder struct { + mock *MockProtocol +} + +// NewMockProtocol creates a new mock instance. +func NewMockProtocol(ctrl *gomock.Controller) *MockProtocol { + mock := &MockProtocol{ctrl: ctrl} + mock.recorder = &MockProtocolMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockProtocol) EXPECT() *MockProtocolMockRecorder { + return m.recorder +} + +// EventHandler mocks base method. +func (m *MockProtocol) EventHandler() EventHandler { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "EventHandler") + ret0, _ := ret[0].(EventHandler) + return ret0 +} + +// EventHandler indicates an expected call of EventHandler. +func (mr *MockProtocolMockRecorder) EventHandler() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EventHandler", reflect.TypeOf((*MockProtocol)(nil).EventHandler)) +} + +// Handle mocks base method. +func (m *MockProtocol) Handle(arg0 context.Context, arg1 Engine, arg2 []byte) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Handle", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// Handle indicates an expected call of Handle. +func (mr *MockProtocolMockRecorder) Handle(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Handle", reflect.TypeOf((*MockProtocol)(nil).Handle), arg0, arg1, arg2) +} + +// MockEventHandler is a mock of EventHandler interface. +type MockEventHandler struct { + ctrl *gomock.Controller + recorder *MockEventHandlerMockRecorder +} + +// MockEventHandlerMockRecorder is the mock recorder for MockEventHandler. +type MockEventHandlerMockRecorder struct { + mock *MockEventHandler +} + +// NewMockEventHandler creates a new mock instance. +func NewMockEventHandler(ctrl *gomock.Controller) *MockEventHandler { + mock := &MockEventHandler{ctrl: ctrl} + mock.recorder = &MockEventHandlerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockEventHandler) EXPECT() *MockEventHandlerMockRecorder { + return m.recorder +} + +// Emit mocks base method. +func (m *MockEventHandler) Emit(arg0 EventType, arg1 string, arg2 []byte, arg3 error) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Emit", arg0, arg1, arg2, arg3) +} + +// Emit indicates an expected call of Emit. +func (mr *MockEventHandlerMockRecorder) Emit(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Emit", reflect.TypeOf((*MockEventHandler)(nil).Emit), arg0, arg1, arg2, arg3) +} diff --git a/v2/pkg/subscription/handler_test.go b/v2/pkg/subscription/handler_test.go index d21e3cb90..125cee516 100644 --- a/v2/pkg/subscription/handler_test.go +++ b/v2/pkg/subscription/handler_test.go @@ -1,624 +1,338 @@ package subscription import ( - "bytes" "context" - "encoding/json" "errors" - "fmt" - "net/http" - "net/http/httptest" + "sync" "testing" "time" + "github.com/golang/mock/gomock" "github.com/jensneuse/abstractlogger" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource" - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient" - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/plan" - "github.com/wundergraph/graphql-go-tools/v2/pkg/graphql" - "github.com/wundergraph/graphql-go-tools/v2/pkg/starwars" - "github.com/wundergraph/graphql-go-tools/v2/pkg/testing/subscriptiontesting" ) -type handlerRoutine func(ctx context.Context) func() bool - -type websocketHook struct { - called bool - reqCtx context.Context - hook func(reqCtx context.Context, operation *graphql.Request) error -} - -func (w *websocketHook) OnBeforeStart(reqCtx context.Context, operation *graphql.Request) error { - w.called = true - if w.hook != nil { - return w.hook(reqCtx, operation) - } - return nil -} - -func TestHandler_Handle(t *testing.T) { - starwars.SetRelativePathToStarWarsPackage("../starwars") - - t.Run("engine v2", func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - chatServer := httptest.NewServer(subscriptiontesting.ChatGraphQLEndpointHandler()) - defer chatServer.Close() - - t.Run("connection_init", func(t *testing.T) { - var initPayloadAuthorization string - - executorPool, _ := setupEngineV2(t, ctx, chatServer.URL) - _, client, handlerRoutine := setupSubscriptionHandlerWithInitFuncTest(t, executorPool, func(ctx context.Context, initPayload InitPayload) (context.Context, error) { - if initPayloadAuthorization == "" { - return ctx, nil - } - - if initPayloadAuthorization != initPayload.Authorization() { - return nil, fmt.Errorf("unknown user: %s", initPayload.Authorization()) - } - - return ctx, nil - }) - - t.Run("should send connection error message when error on read occurrs", func(t *testing.T) { - client.prepareConnectionInitMessage().withError().and().send() - - ctx, cancelFunc := context.WithCancel(context.Background()) - - cancelFunc() - require.Eventually(t, handlerRoutine(ctx), 1*time.Second, 5*time.Millisecond) +func TestUniversalProtocolHandler_Handle(t *testing.T) { + t.Run("should terminate when client is disconnected", func(t *testing.T) { + wg := &sync.WaitGroup{} + wg.Add(1) - expectedMessage := Message{ - Type: MessageTypeConnectionError, - Payload: jsonizePayload(t, "could not read message from client"), - } - - messagesFromServer := client.readFromServer() - assert.Contains(t, messagesFromServer, expectedMessage) - }) - - t.Run("should successfully init connection and respond with ack", func(t *testing.T) { - client.reconnect().and().prepareConnectionInitMessage().withoutError().and().send() + ctrl := gomock.NewController(t) + defer ctrl.Finish() - ctx, cancelFunc := context.WithCancel(context.Background()) + clientMock := NewMockTransportClient(ctrl) + clientMock.EXPECT().IsConnected(). + Return(false). + Times(1) - cancelFunc() - require.Eventually(t, handlerRoutine(ctx), 1*time.Second, 5*time.Millisecond) - - expectedMessage := Message{ - Type: MessageTypeConnectionAck, - } + eventHandlerMock := NewMockEventHandler(ctrl) + eventHandlerMock.EXPECT().Emit(EventTypeOnConnectionOpened, gomock.Eq(""), gomock.Nil(), gomock.Nil()) - messagesFromServer := client.readFromServer() - assert.Contains(t, messagesFromServer, expectedMessage) - }) + protocolMock := NewMockProtocol(ctrl) + protocolMock.EXPECT().EventHandler(). + Return(eventHandlerMock). + Times(2) - t.Run("should send connection error message when error on check initial payload occurrs", func(t *testing.T) { - initPayloadAuthorization = "123" - defer func() { initPayloadAuthorization = "" }() + engineMock := NewMockEngine(ctrl) + engineMock.EXPECT().TerminateAllSubscriptions(eventHandlerMock). + Do(func(_ EventHandler) { + wg.Done() + }). + Times(1) - client.reconnect().and().prepareConnectionInitMessageWithPayload([]byte(`{"Authorization": "111"}`)).withoutError().and().send() + ctx, cancelFunc := context.WithCancel(context.Background()) - ctx, cancelFunc := context.WithCancel(context.Background()) - - cancelFunc() - require.Eventually(t, handlerRoutine(ctx), 1*time.Second, 5*time.Millisecond) - - expectedMessage := Message{ - Type: MessageTypeConnectionTerminate, - Payload: jsonizePayload(t, "failed to accept the websocket connection"), - } - - messagesFromServer := client.readFromServer() - assert.Contains(t, messagesFromServer, expectedMessage) - }) - - t.Run("should successfully init connection and respond with ack when initial payload successfully occurred ", func(t *testing.T) { - initPayloadAuthorization = "123" - defer func() { initPayloadAuthorization = "" }() - - client.reconnect().and().prepareConnectionInitMessageWithPayload([]byte(`{"Authorization": "123"}`)).withoutError().and().send() - - ctx, cancelFunc := context.WithCancel(context.Background()) - - cancelFunc() - require.Eventually(t, handlerRoutine(ctx), 1*time.Second, 5*time.Millisecond) - - expectedMessage := Message{ - Type: MessageTypeConnectionAck, - } - - messagesFromServer := client.readFromServer() - assert.Contains(t, messagesFromServer, expectedMessage) - }) - }) - - t.Run("connection_keep_alive", func(t *testing.T) { - executorPool, _ := setupEngineV2(t, ctx, chatServer.URL) - subscriptionHandler, client, handlerRoutine := setupSubscriptionHandlerTest(t, executorPool) - - t.Run("should successfully send keep alive messages after connection_init", func(t *testing.T) { - keepAliveInterval, err := time.ParseDuration("5ms") - require.NoError(t, err) - - subscriptionHandler.ChangeKeepAliveInterval(keepAliveInterval) - - client.prepareConnectionInitMessage().withoutError().and().send() - ctx, cancelFunc := context.WithCancel(context.Background()) - - handlerRoutineFunc := handlerRoutine(ctx) - go handlerRoutineFunc() - - expectedMessage := Message{ - Type: MessageTypeConnectionKeepAlive, - } - - messagesFromServer := client.readFromServer() - waitForKeepAliveMessage := func() bool { - for len(messagesFromServer) < 2 { - messagesFromServer = client.readFromServer() - } - return true - } - - assert.Eventually(t, waitForKeepAliveMessage, 1*time.Second, 5*time.Millisecond) - assert.Contains(t, messagesFromServer, expectedMessage) - - cancelFunc() - }) - }) - - t.Run("erroneous operation(s)", func(t *testing.T) { - executorPool, _ := setupEngineV2(t, ctx, chatServer.URL) - _, client, handlerRoutine := setupSubscriptionHandlerTest(t, executorPool) - ctx, cancelFunc := context.WithCancel(context.Background()) - handlerRoutineFunc := handlerRoutine(ctx) - go handlerRoutineFunc() - - t.Run("should send error when query contains syntax errors", func(t *testing.T) { - payload := []byte(`{"operationName": "Broken", "query Broken {": "", "variables": null}`) - client.prepareStartMessage("1", payload).withoutError().send() - - waitForClientHavingAMessage := func() bool { - return client.hasMoreMessagesThan(0) - } - require.Eventually(t, waitForClientHavingAMessage, 5*time.Second, 5*time.Millisecond) - - expectedMessage := Message{ - Id: "1", - Type: MessageTypeError, - Payload: []byte(`[{"message":"document doesn't contain any executable operation"}]`), - } - - messagesFromServer := client.readFromServer() - assert.Contains(t, messagesFromServer, expectedMessage) - }) + options := UniversalProtocolHandlerOptions{ + Logger: abstractlogger.Noop{}, + CustomSubscriptionUpdateInterval: 0, + CustomEngine: engineMock, + } + handler, err := NewUniversalProtocolHandlerWithOptions(clientMock, protocolMock, nil, options) + require.NoError(t, err) + assert.Eventually(t, func() bool { + go handler.Handle(ctx) + time.Sleep(5 * time.Millisecond) cancelFunc() - }) - - t.Run("non-subscription query", func(t *testing.T) { - executorPool, hookHolder := setupEngineV2(t, ctx, chatServer.URL) - - t.Run("should process query and return error when query is not valid", func(t *testing.T) { - subscriptionHandler, client, handlerRoutine := setupSubscriptionHandlerTest(t, executorPool) - - payload, err := subscriptiontesting.GraphQLRequestForOperation(subscriptiontesting.InvalidOperation) - require.NoError(t, err) - client.prepareStartMessage("1", payload).withoutError().and().send() - - ctx, cancelFunc := context.WithCancel(context.Background()) - cancelFunc() - handlerRoutineFunc := handlerRoutine(ctx) - go handlerRoutineFunc() - - waitForClientHavingAMessage := func() bool { - return client.hasMoreMessagesThan(0) - } - require.Eventually(t, waitForClientHavingAMessage, 1*time.Second, 5*time.Millisecond) - - expectedErrorMessage := Message{ - Id: "1", - Type: MessageTypeError, - Payload: []byte(`[{"message":"field: serverName not defined on type: Query","path":["query","serverName"]}]`), - } - - messagesFromServer := client.readFromServer() - assert.Contains(t, messagesFromServer, expectedErrorMessage) - assert.Equal(t, 0, subscriptionHandler.ActiveSubscriptions()) - }) - - t.Run("should process and send result for a query", func(t *testing.T) { - subscriptionHandler, client, handlerRoutine := setupSubscriptionHandlerTest(t, executorPool) - - payload, err := subscriptiontesting.GraphQLRequestForOperation(subscriptiontesting.MutationSendMessage) - require.NoError(t, err) - - hookHolder.hook = func(ctx context.Context, operation *graphql.Request) error { - assert.Equal(t, hookHolder.reqCtx, ctx) - assert.Contains(t, operation.Query, "mutation SendMessage") - return nil - } - defer func() { - hookHolder.hook = nil - }() - - client.prepareStartMessage("1", payload).withoutError().and().send() - - ctx, cancelFunc := context.WithCancel(context.Background()) - defer cancelFunc() - handlerRoutineFunc := handlerRoutine(ctx) - go handlerRoutineFunc() - - waitForClientHavingTwoMessages := func() bool { - return client.hasMoreMessagesThan(1) - } - require.Eventually(t, waitForClientHavingTwoMessages, 60*time.Second, 5*time.Millisecond) - - expectedDataMessage := Message{ - Id: "1", - Type: MessageTypeData, - Payload: []byte(`{"data":{"post":{"text":"Hello World!","createdBy":"myuser"}}}`), - } - - expectedCompleteMessage := Message{ - Id: "1", - Type: MessageTypeComplete, - Payload: nil, - } - - messagesFromServer := client.readFromServer() - assert.Contains(t, messagesFromServer, expectedDataMessage) - assert.Contains(t, messagesFromServer, expectedCompleteMessage) - assert.Equal(t, 0, subscriptionHandler.ActiveSubscriptions()) - assert.True(t, hookHolder.called) - }) + <-ctx.Done() // Check if channel is closed + wg.Wait() + return true + }, 1*time.Second, 5*time.Millisecond) + }) - t.Run("should process and send error message from hook for a query", func(t *testing.T) { - subscriptionHandler, client, handlerRoutine := setupSubscriptionHandlerTest(t, executorPool) + t.Run("should terminate when reading on closed connection", func(t *testing.T) { + wg := &sync.WaitGroup{} + wg.Add(1) + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + clientMock := NewMockTransportClient(ctrl) + clientMock.EXPECT().IsConnected(). + Return(true). + Times(1) + clientMock.EXPECT().ReadBytesFromClient(). + Return(nil, ErrTransportClientClosedConnection). + Times(1) + + eventHandlerMock := NewMockEventHandler(ctrl) + eventHandlerMock.EXPECT().Emit(EventTypeOnConnectionOpened, gomock.Eq(""), gomock.Nil(), gomock.Nil()) + + protocolMock := NewMockProtocol(ctrl) + protocolMock.EXPECT().EventHandler(). + Return(eventHandlerMock). + Times(2) + + engineMock := NewMockEngine(ctrl) + engineMock.EXPECT().TerminateAllSubscriptions(eventHandlerMock). + Do(func(_ EventHandler) { + wg.Done() + }). + Times(1) + + ctx, cancelFunc := context.WithCancel(context.Background()) + + options := UniversalProtocolHandlerOptions{ + Logger: abstractlogger.Noop{}, + CustomSubscriptionUpdateInterval: 0, + CustomEngine: engineMock, + } + handler, err := NewUniversalProtocolHandlerWithOptions(clientMock, protocolMock, nil, options) + require.NoError(t, err) - payload, err := subscriptiontesting.GraphQLRequestForOperation(subscriptiontesting.MutationSendMessage) - require.NoError(t, err) + assert.Eventually(t, func() bool { + go handler.Handle(ctx) + time.Sleep(5 * time.Millisecond) + cancelFunc() + <-ctx.Done() // Check if channel is closed + wg.Wait() + return true + }, 1*time.Second, 5*time.Millisecond) + }) - errMsg := "error_on_operation" - hookHolder.hook = func(ctx context.Context, operation *graphql.Request) error { - return errors.New(errMsg) - } - defer func() { - hookHolder.hook = nil - }() + t.Run("should sent event on client read error", func(t *testing.T) { + wg := &sync.WaitGroup{} + wg.Add(1) + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + clientMock := NewMockTransportClient(ctrl) + clientMock.EXPECT().ReadBytesFromClient(). + Return(nil, errors.New("read error")). + MinTimes(1) + clientMock.EXPECT().IsConnected(). + Return(true). + MinTimes(1) + + eventHandlerMock := NewMockEventHandler(ctrl) + eventHandlerMock.EXPECT().Emit(EventTypeOnConnectionError, gomock.Eq(""), gomock.Nil(), gomock.Eq(ErrCouldNotReadMessageFromClient)). + MinTimes(1) + eventHandlerMock.EXPECT().Emit(EventTypeOnConnectionOpened, gomock.Eq(""), gomock.Nil(), gomock.Nil()) + + protocolMock := NewMockProtocol(ctrl) + protocolMock.EXPECT().EventHandler(). + Return(eventHandlerMock). + MinTimes(1) + + engineMock := NewMockEngine(ctrl) + engineMock.EXPECT().TerminateAllSubscriptions(eventHandlerMock). + Do(func(_ EventHandler) { + wg.Done() + }). + Times(1) + + ctx, cancelFunc := context.WithCancel(context.Background()) + + options := UniversalProtocolHandlerOptions{ + Logger: abstractlogger.Noop{}, + CustomSubscriptionUpdateInterval: 0, + CustomEngine: engineMock, + } + handler, err := NewUniversalProtocolHandlerWithOptions(clientMock, protocolMock, nil, options) + require.NoError(t, err) - client.prepareStartMessage("1", payload).withoutError().and().send() + assert.Eventually(t, func() bool { + go handler.Handle(ctx) + time.Sleep(5 * time.Millisecond) + cancelFunc() + <-ctx.Done() // Check if channel is closed + wg.Wait() + return true + }, 1*time.Second, 5*time.Millisecond) + }) - ctx, cancelFunc := context.WithCancel(context.Background()) - cancelFunc() - handlerRoutineFunc := handlerRoutine(ctx) - go handlerRoutineFunc() + t.Run("should handover message to protocol handler", func(t *testing.T) { + wg := &sync.WaitGroup{} + wg.Add(1) + + ctx, cancelFunc := context.WithCancel(context.Background()) + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + clientMock := NewMockTransportClient(ctrl) + clientMock.EXPECT().ReadBytesFromClient(). + Return([]byte(`{"type":"start","id":"1","payload":"{\"query\":\"{ hello }\”}"}`), nil). + MinTimes(1) + clientMock.EXPECT().IsConnected(). + Return(true). + MinTimes(1) + + eventHandlerMock := NewMockEventHandler(ctrl) + eventHandlerMock.EXPECT().Emit(EventTypeOnConnectionOpened, gomock.Eq(""), gomock.Nil(), gomock.Nil()) + + engineMock := NewMockEngine(ctrl) + engineMock.EXPECT().TerminateAllSubscriptions(eventHandlerMock). + Do(func(_ EventHandler) { + wg.Done() + }). + Times(1) + + protocolMock := NewMockProtocol(ctrl) + protocolMock.EXPECT().EventHandler(). + Return(eventHandlerMock). + Times(2) + protocolMock.EXPECT().Handle(assignableToContextWithCancel(ctx), gomock.Eq(engineMock), gomock.Eq([]byte(`{"type":"start","id":"1","payload":"{\"query\":\"{ hello }\”}"}`))). + Return(nil). + MinTimes(1) + + options := UniversalProtocolHandlerOptions{ + Logger: abstractlogger.Noop{}, + CustomSubscriptionUpdateInterval: 0, + CustomEngine: engineMock, + } + handler, err := NewUniversalProtocolHandlerWithOptions(clientMock, protocolMock, nil, options) + require.NoError(t, err) - waitForClientHavingTwoMessages := func() bool { - return client.hasMoreMessagesThan(0) - } - require.Eventually(t, waitForClientHavingTwoMessages, 5*time.Second, 5*time.Millisecond) - - jsonErrMessage, err := json.Marshal(graphql.RequestErrors{ - {Message: errMsg}, - }) - require.NoError(t, err) - expectedErrMessage := Message{ - Id: "1", - Type: MessageTypeError, - Payload: jsonErrMessage, - } + assert.Eventually(t, func() bool { + go handler.Handle(ctx) + time.Sleep(5 * time.Millisecond) + cancelFunc() + <-ctx.Done() // Check if channel is closed + wg.Wait() + return true + }, 1*time.Second, 5*time.Millisecond) + }) - messagesFromServer := client.readFromServer() - assert.Contains(t, messagesFromServer, expectedErrMessage) - assert.Equal(t, 0, subscriptionHandler.ActiveSubscriptions()) - assert.True(t, hookHolder.called) - }) + t.Run("read error time out", func(t *testing.T) { + t.Run("should stop handler when read error timer runs out", func(t *testing.T) { + wg := &sync.WaitGroup{} + wg.Add(1) + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + clientMock := NewMockTransportClient(ctrl) + clientMock.EXPECT().ReadBytesFromClient(). + Return(nil, errors.New("random error")). + MinTimes(1) + clientMock.EXPECT().IsConnected(). + Return(true). + MinTimes(1) + + eventHandlerMock := NewMockEventHandler(ctrl) + eventHandlerMock.EXPECT().Emit(EventTypeOnConnectionError, gomock.Eq(""), gomock.Nil(), gomock.Eq(ErrCouldNotReadMessageFromClient)). + MinTimes(1) + eventHandlerMock.EXPECT().Emit(EventTypeOnConnectionOpened, gomock.Eq(""), gomock.Nil(), gomock.Nil()) + + protocolMock := NewMockProtocol(ctrl) + protocolMock.EXPECT().EventHandler(). + Return(eventHandlerMock). + MinTimes(1) + + engineMock := NewMockEngine(ctrl) + engineMock.EXPECT().TerminateAllSubscriptions(eventHandlerMock). + Do(func(_ EventHandler) { + wg.Done() + }). + Times(1) + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + options := UniversalProtocolHandlerOptions{ + Logger: abstractlogger.Noop{}, + CustomSubscriptionUpdateInterval: 0, + CustomReadErrorTimeOut: 5 * time.Millisecond, + CustomEngine: engineMock, + } + handler, err := NewUniversalProtocolHandlerWithOptions(clientMock, protocolMock, nil, options) + require.NoError(t, err) + + assert.Eventually(t, func() bool { + go handler.Handle(ctx) + time.Sleep(30 * time.Millisecond) + wg.Wait() + return true + }, 1*time.Second, 5*time.Millisecond) }) - t.Run("subscription query", func(t *testing.T) { - executorPool, hookHolder := setupEngineV2(t, ctx, chatServer.URL) - - t.Run("should start subscription on start", func(t *testing.T) { - subscriptionHandler, client, handlerRoutine := setupSubscriptionHandlerTest(t, executorPool) - payload, err := subscriptiontesting.GraphQLRequestForOperation(subscriptiontesting.SubscriptionLiveMessages) - require.NoError(t, err) - client.prepareStartMessage("1", payload).withoutError().and().send() - - ctx, cancelFunc := context.WithCancel(context.Background()) - handlerRoutineFunc := handlerRoutine(ctx) - go handlerRoutineFunc() + t.Run("should continue running handler after intermittent read error", func(t *testing.T) { + wg := &sync.WaitGroup{} + wg.Add(1) - time.Sleep(50 * time.Millisecond) - defer cancelFunc() + ctrl := gomock.NewController(t) + defer ctrl.Finish() - go sendChatMutation(t, chatServer.URL) - - require.Eventually(t, func() bool { - return client.hasMoreMessagesThan(0) - }, 1*time.Second, 10*time.Millisecond) - - expectedMessage := Message{ - Id: "1", - Type: MessageTypeData, - Payload: []byte(`{"data":{"messageAdded":{"text":"Hello World!","createdBy":"myuser"}}}`), + readErrorCounter := 0 + readErrorReturn := func() error { + var err error + if readErrorCounter == 0 { + err = errors.New("random error") } + readErrorCounter++ + return err + } + + clientMock := NewMockTransportClient(ctrl) + clientMock.EXPECT().ReadBytesFromClient(). + DoAndReturn(func() ([]byte, error) { + return nil, readErrorReturn() + }, + ). + MinTimes(1) + clientMock.EXPECT().IsConnected(). + Return(true). + MinTimes(1) + + eventHandlerMock := NewMockEventHandler(ctrl) + eventHandlerMock.EXPECT().Emit(EventTypeOnConnectionError, gomock.Eq(""), gomock.Nil(), gomock.Eq(ErrCouldNotReadMessageFromClient)). + MinTimes(1) + eventHandlerMock.EXPECT().Emit(EventTypeOnConnectionOpened, gomock.Eq(""), gomock.Nil(), gomock.Nil()) + + protocolMock := NewMockProtocol(ctrl) + protocolMock.EXPECT().EventHandler(). + Return(eventHandlerMock). + MinTimes(1) + + engineMock := NewMockEngine(ctrl) + engineMock.EXPECT().TerminateAllSubscriptions(eventHandlerMock). + Do(func(_ EventHandler) { + wg.Done() + }). + Times(1) - messagesFromServer := client.readFromServer() - assert.Contains(t, messagesFromServer, expectedMessage) - assert.Equal(t, 1, subscriptionHandler.ActiveSubscriptions()) - }) - - t.Run("should fail with validation error for invalid Subscription", func(t *testing.T) { - subscriptionHandler, client, handlerRoutine := setupSubscriptionHandlerTest(t, executorPool) - payload, err := subscriptiontesting.GraphQLRequestForOperation(subscriptiontesting.InvalidSubscriptionLiveMessages) - require.NoError(t, err) - client.prepareStartMessage("1", payload).withoutError().and().send() - - ctx, cancelFunc := context.WithCancel(context.Background()) - handlerRoutineFunc := handlerRoutine(ctx) - go handlerRoutineFunc() - - time.Sleep(10 * time.Millisecond) - cancelFunc() - - go sendChatMutation(t, chatServer.URL) - - require.Eventually(t, func() bool { - return client.hasMoreMessagesThan(0) - }, 1*time.Second, 10*time.Millisecond) - - messagesFromServer := client.readFromServer() - assert.Len(t, messagesFromServer, 1) - assert.Equal(t, "1", messagesFromServer[0].Id) - assert.Equal(t, MessageTypeError, messagesFromServer[0].Type) - assert.Equal(t, `[{"message":"differing fields for objectName 'a' on (potentially) same type","path":["subscription","messageAdded"]}]`, string(messagesFromServer[0].Payload)) - assert.Equal(t, 1, subscriptionHandler.ActiveSubscriptions()) - }) - - t.Run("should stop subscription on stop and send complete message to client", func(t *testing.T) { - subscriptionHandler, client, handlerRoutine := setupSubscriptionHandlerTest(t, executorPool) - client.reconnect().prepareStopMessage("1").withoutError().and().send() - - ctx, cancelFunc := context.WithCancel(context.Background()) - handlerRoutineFunc := handlerRoutine(ctx) - go handlerRoutineFunc() - - waitForCanceledSubscription := func() bool { - for subscriptionHandler.ActiveSubscriptions() > 0 { - } - return true - } - - assert.Eventually(t, waitForCanceledSubscription, 1*time.Second, 5*time.Millisecond) - assert.Equal(t, 0, subscriptionHandler.ActiveSubscriptions()) - - expectedMessage := Message{ - Id: "1", - Type: MessageTypeComplete, - Payload: nil, - } - - messagesFromServer := client.readFromServer() - assert.Contains(t, messagesFromServer, expectedMessage) - - cancelFunc() - }) - - t.Run("should interrupt subscription on start and return error message from hook", func(t *testing.T) { - subscriptionHandler, client, handlerRoutine := setupSubscriptionHandlerTest(t, executorPool) - - payload, err := subscriptiontesting.GraphQLRequestForOperation(subscriptiontesting.SubscriptionLiveMessages) - require.NoError(t, err) - - errMsg := "sub_interrupted" - hookHolder.hook = func(ctx context.Context, operation *graphql.Request) error { - return errors.New(errMsg) - } - - client.prepareStartMessage("1", payload).withoutError().and().send() - - ctx, cancelFunc := context.WithCancel(context.Background()) - handlerRoutineFunc := handlerRoutine(ctx) - go handlerRoutineFunc() + ctx, cancelFunc := context.WithCancel(context.Background()) + options := UniversalProtocolHandlerOptions{ + Logger: abstractlogger.Noop{}, + CustomSubscriptionUpdateInterval: 0, + CustomReadErrorTimeOut: 5 * time.Millisecond, + CustomEngine: engineMock, + } + handler, err := NewUniversalProtocolHandlerWithOptions(clientMock, protocolMock, nil, options) + require.NoError(t, err) + + assert.Eventually(t, func() bool { + go handler.Handle(ctx) time.Sleep(10 * time.Millisecond) cancelFunc() - - go sendChatMutation(t, chatServer.URL) - - require.Eventually(t, func() bool { - return client.hasMoreMessagesThan(0) - }, 1*time.Second, 10*time.Millisecond) - - jsonErrMessage, err := json.Marshal(graphql.RequestErrors{ - {Message: errMsg}, - }) - require.NoError(t, err) - expectedErrMessage := Message{ - Id: "1", - Type: MessageTypeError, - Payload: jsonErrMessage, - } - - messagesFromServer := client.readFromServer() - assert.Contains(t, messagesFromServer, expectedErrMessage) - assert.Equal(t, 0, subscriptionHandler.ActiveSubscriptions()) - assert.True(t, hookHolder.called) - }) + <-ctx.Done() // Check if channel is closed + wg.Wait() + return true + }, 1*time.Second, 5*time.Millisecond) }) - - t.Run("connection_terminate", func(t *testing.T) { - executorPool, _ := setupEngineV2(t, ctx, chatServer.URL) - _, client, handlerRoutine := setupSubscriptionHandlerTest(t, executorPool) - - t.Run("should successfully disconnect from client", func(t *testing.T) { - client.prepareConnectionTerminateMessage().withoutError().and().send() - require.True(t, client.connected) - - ctx, cancelFunc := context.WithCancel(context.Background()) - - cancelFunc() - require.Eventually(t, handlerRoutine(ctx), 1*time.Second, 5*time.Millisecond) - - assert.False(t, client.connected) - }) - }) - - t.Run("client is disconnected", func(t *testing.T) { - executorPool, _ := setupEngineV2(t, ctx, chatServer.URL) - _, client, handlerRoutine := setupSubscriptionHandlerTest(t, executorPool) - - t.Run("server should not read from client and stop handler", func(t *testing.T) { - err := client.Disconnect() - require.NoError(t, err) - require.False(t, client.connected) - - client.prepareConnectionInitMessage().withoutError() - ctx, cancelFunc := context.WithCancel(context.Background()) - - cancelFunc() - require.Eventually(t, handlerRoutine(ctx), 1*time.Second, 5*time.Millisecond) - - assert.False(t, client.serverHasRead) - }) - }) - }) - -} - -func setupEngineV2(t *testing.T, ctx context.Context, chatServerURL string) (*ExecutorV2Pool, *websocketHook) { - chatSchemaBytes, err := subscriptiontesting.LoadSchemaFromExamplesDirectoryWithinPkg() - require.NoError(t, err) - - chatSchema, err := graphql.NewSchemaFromReader(bytes.NewBuffer(chatSchemaBytes)) - require.NoError(t, err) - - engineConf := graphql.NewEngineV2Configuration(chatSchema) - engineConf.SetDataSources([]plan.DataSourceConfiguration{ - { - RootNodes: []plan.TypeField{ - {TypeName: "Mutation", FieldNames: []string{"post"}}, - {TypeName: "Subscription", FieldNames: []string{"messageAdded"}}, - }, - ChildNodes: []plan.TypeField{ - {TypeName: "Message", FieldNames: []string{"text", "createdBy"}}, - }, - Factory: &graphql_datasource.Factory{ - HTTPClient: httpclient.DefaultNetHttpClient, - }, - Custom: graphql_datasource.ConfigJson(graphql_datasource.Configuration{ - Fetch: graphql_datasource.FetchConfiguration{ - URL: chatServerURL, - Method: http.MethodPost, - Header: nil, - }, - Subscription: graphql_datasource.SubscriptionConfiguration{ - URL: chatServerURL, - }, - }), - }, - }) - engineConf.SetFieldConfigurations([]plan.FieldConfiguration{ - { - TypeName: "Mutation", - FieldName: "post", - Arguments: []plan.ArgumentConfiguration{ - { - Name: "roomName", - SourceType: plan.FieldArgumentSource, - }, - { - Name: "username", - SourceType: plan.FieldArgumentSource, - }, - { - Name: "text", - SourceType: plan.FieldArgumentSource, - }, - }, - }, - { - TypeName: "Subscription", - FieldName: "messageAdded", - Arguments: []plan.ArgumentConfiguration{ - { - Name: "roomName", - SourceType: plan.FieldArgumentSource, - }, - }, - }, }) - - hookHolder := &websocketHook{ - reqCtx: context.Background(), - } - engineConf.SetWebsocketBeforeStartHook(hookHolder) - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, "http://localhost:8080", nil) - require.NoError(t, err) - - req.Header.Set("X-Other-Key", "x-other-value") - - initCtx := NewInitialHttpRequestContext(req) - - engine, err := graphql.NewExecutionEngineV2(initCtx, abstractlogger.NoopLogger, engineConf) - require.NoError(t, err) - - executorPool := NewExecutorV2Pool(engine, hookHolder.reqCtx) - - return executorPool, hookHolder -} - -func setupSubscriptionHandlerTest(t *testing.T, executorPool ExecutorPool) (subscriptionHandler *Handler, client *mockClient, routine handlerRoutine) { - return setupSubscriptionHandlerWithInitFuncTest(t, executorPool, nil) -} - -func setupSubscriptionHandlerWithInitFuncTest( - t *testing.T, - executorPool ExecutorPool, - initFunc WebsocketInitFunc, -) (subscriptionHandler *Handler, client *mockClient, routine handlerRoutine) { - client = newMockClient() - - var err error - subscriptionHandler, err = NewHandlerWithInitFunc(abstractlogger.NoopLogger, client, executorPool, initFunc) - require.NoError(t, err) - - routine = func(ctx context.Context) func() bool { - return func() bool { - subscriptionHandler.Handle(ctx) - return true - } - } - - return subscriptionHandler, client, routine -} - -func jsonizePayload(t *testing.T, payload interface{}) json.RawMessage { - jsonBytes, err := json.Marshal(payload) - require.NoError(t, err) - - return jsonBytes -} - -func sendChatMutation(t *testing.T, url string) { - reqBody, err := subscriptiontesting.GraphQLRequestForOperation(subscriptiontesting.MutationSendMessage) - require.NoError(t, err) - - req, err := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(reqBody)) - require.NoError(t, err) - req.Header.Set("Content-Type", "application/json") - - httpClient := http.Client{} - resp, err := httpClient.Do(req) - require.NoError(t, err) - require.Equal(t, http.StatusOK, resp.StatusCode) } diff --git a/v2/pkg/subscription/init.go b/v2/pkg/subscription/init.go index fb4eaed2c..f63a3a7e3 100644 --- a/v2/pkg/subscription/init.go +++ b/v2/pkg/subscription/init.go @@ -1,8 +1,17 @@ package subscription -import "encoding/json" +import ( + "context" + "encoding/json" +) + +// WebsocketInitFunc is called when the server receives connection init message from the client. +// This can be used to check initial payload to see whether to accept the websocket connection. +// Deprecated: Use websocket.InitFunc instead. +type WebsocketInitFunc func(ctx context.Context, initPayload InitPayload) (context.Context, error) // InitPayload is a structure that is parsed from the websocket init message payload. +// Deprecated: Use websocket.InitPayload instead. type InitPayload json.RawMessage // GetString safely gets a string value from the payload. It returns an empty string if the @@ -25,7 +34,7 @@ func (p InitPayload) GetString(key string) string { return "" } -// Authorization is a short hand for getting the Authorization header from the +// Authorization is a shorthand for getting the Authorization header from the // payload. func (p InitPayload) Authorization() string { if value := p.GetString("Authorization"); value != "" { diff --git a/v2/pkg/subscription/legacy_handler.go b/v2/pkg/subscription/legacy_handler.go new file mode 100644 index 000000000..e49f5b1b9 --- /dev/null +++ b/v2/pkg/subscription/legacy_handler.go @@ -0,0 +1,494 @@ +package subscription + +import ( + "bytes" + "context" + "encoding/json" + "sync" + "time" + + "github.com/jensneuse/abstractlogger" + + "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" + "github.com/wundergraph/graphql-go-tools/v2/pkg/graphql" +) + +const ( + MessageTypeConnectionInit = "connection_init" + MessageTypeConnectionAck = "connection_ack" + MessageTypeConnectionError = "connection_error" + MessageTypeConnectionTerminate = "connection_terminate" + MessageTypeConnectionKeepAlive = "ka" + MessageTypeStart = "start" + MessageTypeStop = "stop" + MessageTypeData = "data" + MessageTypeError = "error" + MessageTypeComplete = "complete" +) + +// Message defines the actual subscription message which will be passed from client to server and vice versa. +// +// Deprecated: Prefer using TransportClient that is based on byte slices instead of this Message struct. +type Message struct { + Id string `json:"id"` + Type string `json:"type"` + Payload json.RawMessage `json:"payload"` +} + +// Client provides an interface which can be implemented by any possible subscription client like websockets, mqtt, etc. +// +// Deprecated: Use TransportClient instead. +type Client interface { + // ReadFromClient will invoke a read operation from the client connection. + ReadFromClient() (*Message, error) + // WriteToClient will invoke a write operation to the client connection. + WriteToClient(Message) error + // IsConnected will indicate if a connection is still established. + IsConnected() bool + // Disconnect will close the connection between server and client. + Disconnect() error +} + +// Handler is the actual subscription handler which will keep track on how to handle messages coming from the client. +type Handler struct { + logger abstractlogger.Logger + // client will hold the subscription client implementation. + client Client + // keepAliveInterval is the actual interval on which the server send keep alive messages to the client. + keepAliveInterval time.Duration + // subscriptionUpdateInterval is the actual interval on which the server sends subscription updates to the client. + subscriptionUpdateInterval time.Duration + // subCancellations stores a map containing the cancellation functions to every active subscription. + subCancellations subscriptionCancellations + // executorPool is responsible to create and hold executors. + executorPool ExecutorPool + // bufferPool will hold buffers. + bufferPool *sync.Pool + // initFunc will check initial payload to see whether to accept the websocket connection. + initFunc WebsocketInitFunc +} + +func NewHandlerWithInitFunc( + logger abstractlogger.Logger, + client Client, + executorPool ExecutorPool, + initFunc WebsocketInitFunc, +) (*Handler, error) { + keepAliveInterval, err := time.ParseDuration(DefaultKeepAliveInterval) + if err != nil { + return nil, err + } + + subscriptionUpdateInterval, err := time.ParseDuration(DefaultSubscriptionUpdateInterval) + if err != nil { + return nil, err + } + + return &Handler{ + logger: logger, + client: client, + keepAliveInterval: keepAliveInterval, + subscriptionUpdateInterval: subscriptionUpdateInterval, + subCancellations: subscriptionCancellations{}, + executorPool: executorPool, + bufferPool: &sync.Pool{ + New: func() interface{} { + writer := graphql.NewEngineResultWriterFromBuffer(bytes.NewBuffer(make([]byte, 0, 1024))) + return &writer + }, + }, + initFunc: initFunc, + }, nil +} + +// NewHandler creates a new subscription handler. +func NewHandler(logger abstractlogger.Logger, client Client, executorPool ExecutorPool) (*Handler, error) { + return NewHandlerWithInitFunc(logger, client, executorPool, nil) +} + +// Handle will handle the subscription connection. +func (h *Handler) Handle(ctx context.Context) { + defer h.subCancellations.CancelAll() + + for { + if !h.client.IsConnected() { + h.logger.Debug("subscription.Handler.Handle()", + abstractlogger.String("message", "client has disconnected"), + ) + + return + } + + message, err := h.client.ReadFromClient() + if err != nil { + h.logger.Error("subscription.Handler.Handle()", + abstractlogger.Error(err), + abstractlogger.Any("message", message), + ) + + h.handleConnectionError("could not read message from client") + } else if message != nil { + switch message.Type { + case MessageTypeConnectionInit: + ctx, err = h.handleInit(ctx, message.Payload) + if err != nil { + h.terminateConnection("failed to accept the websocket connection") + return + } + + go h.handleKeepAlive(ctx) + case MessageTypeStart: + h.handleStart(ctx, message.Id, message.Payload) + case MessageTypeStop: + h.handleStop(message.Id) + case MessageTypeConnectionTerminate: + h.handleConnectionTerminate() + return + } + } + + select { + case <-ctx.Done(): + return + default: + continue + } + } +} + +// ChangeKeepAliveInterval can be used to change the keep alive interval. +func (h *Handler) ChangeKeepAliveInterval(d time.Duration) { + h.keepAliveInterval = d +} + +// ChangeSubscriptionUpdateInterval can be used to change the update interval. +func (h *Handler) ChangeSubscriptionUpdateInterval(d time.Duration) { + h.subscriptionUpdateInterval = d +} + +// handleInit will handle an init message. +func (h *Handler) handleInit(ctx context.Context, payload []byte) (extendedCtx context.Context, err error) { + if h.initFunc != nil { + var initPayload InitPayload + // decode initial payload + if len(payload) > 0 { + initPayload = payload + } + // check initial payload to see whether to accept the websocket connection + if extendedCtx, err = h.initFunc(ctx, initPayload); err != nil { + return extendedCtx, err + } + } else { + extendedCtx = ctx + } + + ackMessage := Message{ + Type: MessageTypeConnectionAck, + } + + if err = h.client.WriteToClient(ackMessage); err != nil { + return extendedCtx, err + } + + return extendedCtx, nil +} + +// handleStart will handle s start message. +func (h *Handler) handleStart(ctx context.Context, id string, payload []byte) { + executor, err := h.executorPool.Get(payload) + if err != nil { + h.logger.Error("subscription.Handler.handleStart()", + abstractlogger.Error(err), + ) + + h.handleError(id, graphql.RequestErrorsFromError(err)) + return + } + + if err = h.handleOnBeforeStart(executor); err != nil { + h.handleError(id, graphql.RequestErrorsFromError(err)) + return + } + + if executor.OperationType() == ast.OperationTypeSubscription { + ctx, subsErr := h.subCancellations.AddWithParent(id, ctx) + if subsErr != nil { + h.handleError(id, graphql.RequestErrorsFromError(subsErr)) + return + } + go h.startSubscription(ctx, id, executor) + return + } + + go h.handleNonSubscriptionOperation(ctx, id, executor) +} + +func (h *Handler) handleOnBeforeStart(executor Executor) error { + switch e := executor.(type) { + case *ExecutorV2: + if hook := e.engine.GetWebsocketBeforeStartHook(); hook != nil { + return hook.OnBeforeStart(e.reqCtx, e.operation) + } + } + + return nil +} + +// handleNonSubscriptionOperation will handle a non-subscription operation like a query or a mutation. +func (h *Handler) handleNonSubscriptionOperation(ctx context.Context, id string, executor Executor) { + defer func() { + err := h.executorPool.Put(executor) + if err != nil { + h.logger.Error("subscription.Handle.handleNonSubscriptionOperation()", + abstractlogger.Error(err), + ) + } + }() + + executor.SetContext(ctx) + buf := h.bufferPool.Get().(*graphql.EngineResultWriter) + buf.Reset() + + defer h.bufferPool.Put(buf) + + // err := executor.Execute(executionContext, node, buf) + err := executor.Execute(buf) + if err != nil { + h.logger.Error("subscription.Handle.handleNonSubscriptionOperation()", + abstractlogger.Error(err), + ) + + h.handleError(id, graphql.RequestErrorsFromError(err)) + return + } + + h.logger.Debug("subscription.Handle.handleNonSubscriptionOperation()", + abstractlogger.ByteString("execution_result", buf.Bytes()), + ) + + h.sendData(id, buf.Bytes()) + h.sendComplete(id) +} + +// startSubscription will invoke the actual subscription. +func (h *Handler) startSubscription(ctx context.Context, id string, executor Executor) { + defer func() { + err := h.executorPool.Put(executor) + if err != nil { + h.logger.Error("subscription.Handle.startSubscription()", + abstractlogger.Error(err), + ) + } + }() + + executor.SetContext(ctx) + buf := h.bufferPool.Get().(*graphql.EngineResultWriter) + buf.Reset() + + defer h.bufferPool.Put(buf) + + h.executeSubscription(buf, id, executor) + + for { + buf.Reset() + select { + case <-ctx.Done(): + return + case <-time.After(h.subscriptionUpdateInterval): + h.executeSubscription(buf, id, executor) + } + } + +} + +// executeSubscription will keep execution the subscription until it ends. +func (h *Handler) executeSubscription(buf *graphql.EngineResultWriter, id string, executor Executor) { + buf.SetFlushCallback(func(data []byte) { + h.logger.Debug("subscription.Handle.executeSubscription()", + abstractlogger.ByteString("execution_result", data), + ) + h.sendData(id, data) + }) + defer buf.SetFlushCallback(nil) + + err := executor.Execute(buf) + if err != nil { + h.logger.Error("subscription.Handle.executeSubscription()", + abstractlogger.Error(err), + ) + + h.handleError(id, graphql.RequestErrorsFromError(err)) + return + } + + if buf.Len() > 0 { + data := buf.Bytes() + h.logger.Debug("subscription.Handle.executeSubscription()", + abstractlogger.ByteString("execution_result", data), + ) + h.sendData(id, data) + } +} + +// handleStop will handle a stop message, +func (h *Handler) handleStop(id string) { + h.subCancellations.Cancel(id) + h.sendComplete(id) +} + +// sendData will send a data message to the client. +func (h *Handler) sendData(id string, responseData []byte) { + dataMessage := Message{ + Id: id, + Type: MessageTypeData, + Payload: responseData, + } + + err := h.client.WriteToClient(dataMessage) + if err != nil { + h.logger.Error("subscription.Handler.sendData()", + abstractlogger.Error(err), + ) + } +} + +// nolint +// sendComplete will send a complete message to the client. +func (h *Handler) sendComplete(id string) { + completeMessage := Message{ + Id: id, + Type: MessageTypeComplete, + Payload: nil, + } + + err := h.client.WriteToClient(completeMessage) + if err != nil { + h.logger.Error("subscription.Handler.sendComplete()", + abstractlogger.Error(err), + ) + } +} + +// handleConnectionTerminate will handle a connection terminate message. +func (h *Handler) handleConnectionTerminate() { + err := h.client.Disconnect() + if err != nil { + h.logger.Error("subscription.Handler.handleConnectionTerminate()", + abstractlogger.Error(err), + ) + } +} + +// handleKeepAlive will handle the keep alive loop. +func (h *Handler) handleKeepAlive(ctx context.Context) { + for { + select { + case <-ctx.Done(): + return + case <-time.After(h.keepAliveInterval): + h.sendKeepAlive() + } + } +} + +// sendKeepAlive will send a keep alive message to the client. +func (h *Handler) sendKeepAlive() { + keepAliveMessage := Message{ + Type: MessageTypeConnectionKeepAlive, + } + + err := h.client.WriteToClient(keepAliveMessage) + if err != nil { + h.logger.Error("subscription.Handler.sendKeepAlive()", + abstractlogger.Error(err), + ) + } +} + +func (h *Handler) terminateConnection(reason interface{}) { + payloadBytes, err := json.Marshal(reason) + if err != nil { + h.logger.Error("subscription.Handler.terminateConnection()", + abstractlogger.Error(err), + abstractlogger.Any("errorPayload", reason), + ) + } + + connectionErrorMessage := Message{ + Type: MessageTypeConnectionTerminate, + Payload: payloadBytes, + } + + err = h.client.WriteToClient(connectionErrorMessage) + if err != nil { + h.logger.Error("subscription.Handler.terminateConnection()", + abstractlogger.Error(err), + ) + + err := h.client.Disconnect() + if err != nil { + h.logger.Error("subscription.Handler.terminateConnection()", + abstractlogger.Error(err), + ) + } + } +} + +// handleConnectionError will handle a connection error message. +func (h *Handler) handleConnectionError(errorPayload interface{}) { + payloadBytes, err := json.Marshal(errorPayload) + if err != nil { + h.logger.Error("subscription.Handler.handleConnectionError()", + abstractlogger.Error(err), + abstractlogger.Any("errorPayload", errorPayload), + ) + } + + connectionErrorMessage := Message{ + Type: MessageTypeConnectionError, + Payload: payloadBytes, + } + + err = h.client.WriteToClient(connectionErrorMessage) + if err != nil { + h.logger.Error("subscription.Handler.handleConnectionError()", + abstractlogger.Error(err), + ) + + err := h.client.Disconnect() + if err != nil { + h.logger.Error("subscription.Handler.handleError()", + abstractlogger.Error(err), + ) + } + } +} + +// handleError will handle an error message. +func (h *Handler) handleError(id string, errors graphql.RequestErrors) { + payloadBytes, err := json.Marshal(errors) + if err != nil { + h.logger.Error("subscription.Handler.handleError()", + abstractlogger.Error(err), + abstractlogger.Any("errors", errors), + ) + } + + errorMessage := Message{ + Id: id, + Type: MessageTypeError, + Payload: payloadBytes, + } + + err = h.client.WriteToClient(errorMessage) + if err != nil { + h.logger.Error("subscription.Handler.handleError()", + abstractlogger.Error(err), + ) + } +} + +// ActiveSubscriptions will return the actual number of active subscriptions for that client. +func (h *Handler) ActiveSubscriptions() int { + return h.subCancellations.Len() +} diff --git a/v2/pkg/subscription/legacy_handler_test.go b/v2/pkg/subscription/legacy_handler_test.go new file mode 100644 index 000000000..9cdbcdb81 --- /dev/null +++ b/v2/pkg/subscription/legacy_handler_test.go @@ -0,0 +1,670 @@ +package subscription + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/jensneuse/abstractlogger" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/plan" + "github.com/wundergraph/graphql-go-tools/v2/pkg/graphql" + "github.com/wundergraph/graphql-go-tools/v2/pkg/starwars" + "github.com/wundergraph/graphql-go-tools/v2/pkg/testing/subscriptiontesting" +) + +type handlerRoutine func(ctx context.Context) func() bool + +type websocketHook struct { + called bool + reqCtx context.Context + hook func(reqCtx context.Context, operation *graphql.Request) error +} + +func (w *websocketHook) OnBeforeStart(reqCtx context.Context, operation *graphql.Request) error { + w.called = true + if w.hook != nil { + return w.hook(reqCtx, operation) + } + return nil +} + +func TestHandler_Handle(t *testing.T) { + starwars.SetRelativePathToStarWarsPackage("../starwars") + + t.Run("engine v2", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + chatServer := httptest.NewServer(subscriptiontesting.ChatGraphQLEndpointHandler()) + defer chatServer.Close() + + t.Run("connection_init", func(t *testing.T) { + var initPayloadAuthorization string + + executorPool, _ := setupEngineV2(t, ctx, chatServer.URL) + _, client, handlerRoutine := setupSubscriptionHandlerWithInitFuncTest(t, executorPool, func(ctx context.Context, initPayload InitPayload) (context.Context, error) { + if initPayloadAuthorization == "" { + return ctx, nil + } + + if initPayloadAuthorization != initPayload.Authorization() { + return nil, fmt.Errorf("unknown user: %s", initPayload.Authorization()) + } + + return ctx, nil + }) + + t.Run("should send connection error message when error on read occurrs", func(t *testing.T) { + client.prepareConnectionInitMessage().withError().and().send() + + ctx, cancelFunc := context.WithCancel(context.Background()) + + cancelFunc() + require.Eventually(t, handlerRoutine(ctx), 1*time.Second, 5*time.Millisecond) + + expectedMessage := Message{ + Type: MessageTypeConnectionError, + Payload: jsonizePayload(t, "could not read message from client"), + } + + messagesFromServer := client.readFromServer() + assert.Contains(t, messagesFromServer, expectedMessage) + }) + + t.Run("should successfully init connection and respond with ack", func(t *testing.T) { + client.reconnect().and().prepareConnectionInitMessage().withoutError().and().send() + + ctx, cancelFunc := context.WithCancel(context.Background()) + + cancelFunc() + require.Eventually(t, handlerRoutine(ctx), 1*time.Second, 5*time.Millisecond) + + expectedMessage := Message{ + Type: MessageTypeConnectionAck, + } + + messagesFromServer := client.readFromServer() + assert.Contains(t, messagesFromServer, expectedMessage) + }) + + t.Run("should send connection error message when error on check initial payload occurrs", func(t *testing.T) { + initPayloadAuthorization = "123" + defer func() { initPayloadAuthorization = "" }() + + client.reconnect().and().prepareConnectionInitMessageWithPayload([]byte(`{"Authorization": "111"}`)).withoutError().and().send() + + ctx, cancelFunc := context.WithCancel(context.Background()) + + cancelFunc() + require.Eventually(t, handlerRoutine(ctx), 1*time.Second, 5*time.Millisecond) + + expectedMessage := Message{ + Type: MessageTypeConnectionTerminate, + Payload: jsonizePayload(t, "failed to accept the websocket connection"), + } + + messagesFromServer := client.readFromServer() + assert.Contains(t, messagesFromServer, expectedMessage) + }) + + t.Run("should successfully init connection and respond with ack when initial payload successfully occurred ", func(t *testing.T) { + initPayloadAuthorization = "123" + defer func() { initPayloadAuthorization = "" }() + + client.reconnect().and().prepareConnectionInitMessageWithPayload([]byte(`{"Authorization": "123"}`)).withoutError().and().send() + + ctx, cancelFunc := context.WithCancel(context.Background()) + + cancelFunc() + require.Eventually(t, handlerRoutine(ctx), 1*time.Second, 5*time.Millisecond) + + expectedMessage := Message{ + Type: MessageTypeConnectionAck, + } + + messagesFromServer := client.readFromServer() + assert.Contains(t, messagesFromServer, expectedMessage) + }) + }) + + t.Run("connection_keep_alive", func(t *testing.T) { + executorPool, _ := setupEngineV2(t, ctx, chatServer.URL) + subscriptionHandler, client, handlerRoutine := setupSubscriptionHandlerTest(t, executorPool) + + t.Run("should successfully send keep alive messages after connection_init", func(t *testing.T) { + keepAliveInterval, err := time.ParseDuration("5ms") + require.NoError(t, err) + + subscriptionHandler.ChangeKeepAliveInterval(keepAliveInterval) + + client.prepareConnectionInitMessage().withoutError().and().send() + ctx, cancelFunc := context.WithCancel(context.Background()) + + handlerRoutineFunc := handlerRoutine(ctx) + go handlerRoutineFunc() + + expectedMessage := Message{ + Type: MessageTypeConnectionKeepAlive, + } + + messagesFromServer := client.readFromServer() + waitForKeepAliveMessage := func() bool { + for len(messagesFromServer) < 2 { + messagesFromServer = client.readFromServer() + } + return true + } + + assert.Eventually(t, waitForKeepAliveMessage, 1*time.Second, 5*time.Millisecond) + assert.Contains(t, messagesFromServer, expectedMessage) + + cancelFunc() + }) + }) + + t.Run("erroneous operation(s)", func(t *testing.T) { + executorPool, _ := setupEngineV2(t, ctx, chatServer.URL) + _, client, handlerRoutine := setupSubscriptionHandlerTest(t, executorPool) + ctx, cancelFunc := context.WithCancel(context.Background()) + handlerRoutineFunc := handlerRoutine(ctx) + go handlerRoutineFunc() + + t.Run("should send error when query contains syntax errors", func(t *testing.T) { + payload := []byte(`{"operationName": "Broken", "query Broken {": "", "variables": null}`) + client.prepareStartMessage("1", payload).withoutError().send() + + waitForClientHavingAMessage := func() bool { + return client.hasMoreMessagesThan(0) + } + require.Eventually(t, waitForClientHavingAMessage, 5*time.Second, 5*time.Millisecond) + + expectedMessage := Message{ + Id: "1", + Type: MessageTypeError, + Payload: []byte(`[{"message":"document doesn't contain any executable operation"}]`), + } + + messagesFromServer := client.readFromServer() + assert.Contains(t, messagesFromServer, expectedMessage) + }) + + cancelFunc() + }) + + t.Run("non-subscription query", func(t *testing.T) { + executorPool, hookHolder := setupEngineV2(t, ctx, chatServer.URL) + + t.Run("should process query and return error when query is not valid", func(t *testing.T) { + subscriptionHandler, client, handlerRoutine := setupSubscriptionHandlerTest(t, executorPool) + + payload, err := subscriptiontesting.GraphQLRequestForOperation(subscriptiontesting.InvalidOperation) + require.NoError(t, err) + client.prepareStartMessage("1", payload).withoutError().and().send() + + ctx, cancelFunc := context.WithCancel(context.Background()) + cancelFunc() + handlerRoutineFunc := handlerRoutine(ctx) + go handlerRoutineFunc() + + waitForClientHavingAMessage := func() bool { + return client.hasMoreMessagesThan(0) + } + require.Eventually(t, waitForClientHavingAMessage, 1*time.Second, 5*time.Millisecond) + + expectedErrorMessage := Message{ + Id: "1", + Type: MessageTypeError, + Payload: []byte(`[{"message":"field: serverName not defined on type: Query","path":["query","serverName"]}]`), + } + + messagesFromServer := client.readFromServer() + assert.Contains(t, messagesFromServer, expectedErrorMessage) + assert.Equal(t, 0, subscriptionHandler.ActiveSubscriptions()) + }) + + t.Run("should process and send result for a query", func(t *testing.T) { + subscriptionHandler, client, handlerRoutine := setupSubscriptionHandlerTest(t, executorPool) + + payload, err := subscriptiontesting.GraphQLRequestForOperation(subscriptiontesting.MutationSendMessage) + require.NoError(t, err) + + hookHolder.hook = func(ctx context.Context, operation *graphql.Request) error { + assert.Equal(t, hookHolder.reqCtx, ctx) + assert.Contains(t, operation.Query, "mutation SendMessage") + return nil + } + defer func() { + hookHolder.hook = nil + }() + + client.prepareStartMessage("1", payload).withoutError().and().send() + + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + handlerRoutineFunc := handlerRoutine(ctx) + go handlerRoutineFunc() + + waitForClientHavingTwoMessages := func() bool { + return client.hasMoreMessagesThan(1) + } + require.Eventually(t, waitForClientHavingTwoMessages, 60*time.Second, 5*time.Millisecond) + + expectedDataMessage := Message{ + Id: "1", + Type: MessageTypeData, + Payload: []byte(`{"data":{"post":{"text":"Hello World!","createdBy":"myuser"}}}`), + } + + expectedCompleteMessage := Message{ + Id: "1", + Type: MessageTypeComplete, + Payload: nil, + } + + messagesFromServer := client.readFromServer() + assert.Contains(t, messagesFromServer, expectedDataMessage) + assert.Contains(t, messagesFromServer, expectedCompleteMessage) + assert.Equal(t, 0, subscriptionHandler.ActiveSubscriptions()) + assert.True(t, hookHolder.called) + }) + + t.Run("should process and send error message from hook for a query", func(t *testing.T) { + subscriptionHandler, client, handlerRoutine := setupSubscriptionHandlerTest(t, executorPool) + + payload, err := subscriptiontesting.GraphQLRequestForOperation(subscriptiontesting.MutationSendMessage) + require.NoError(t, err) + + errMsg := "error_on_operation" + hookHolder.hook = func(ctx context.Context, operation *graphql.Request) error { + return errors.New(errMsg) + } + defer func() { + hookHolder.hook = nil + }() + + client.prepareStartMessage("1", payload).withoutError().and().send() + + ctx, cancelFunc := context.WithCancel(context.Background()) + cancelFunc() + handlerRoutineFunc := handlerRoutine(ctx) + go handlerRoutineFunc() + + waitForClientHavingTwoMessages := func() bool { + return client.hasMoreMessagesThan(0) + } + require.Eventually(t, waitForClientHavingTwoMessages, 5*time.Second, 5*time.Millisecond) + + jsonErrMessage, err := json.Marshal(graphql.RequestErrors{ + {Message: errMsg}, + }) + require.NoError(t, err) + expectedErrMessage := Message{ + Id: "1", + Type: MessageTypeError, + Payload: jsonErrMessage, + } + + messagesFromServer := client.readFromServer() + assert.Contains(t, messagesFromServer, expectedErrMessage) + assert.Equal(t, 0, subscriptionHandler.ActiveSubscriptions()) + assert.True(t, hookHolder.called) + }) + + }) + + t.Run("subscription query", func(t *testing.T) { + executorPool, hookHolder := setupEngineV2(t, ctx, chatServer.URL) + + t.Run("should start subscription on start", func(t *testing.T) { + subscriptionHandler, client, handlerRoutine := setupSubscriptionHandlerTest(t, executorPool) + payload, err := subscriptiontesting.GraphQLRequestForOperation(subscriptiontesting.SubscriptionLiveMessages) + require.NoError(t, err) + client.prepareStartMessage("1", payload).withoutError().and().send() + + ctx, cancelFunc := context.WithCancel(context.Background()) + handlerRoutineFunc := handlerRoutine(ctx) + go handlerRoutineFunc() + + time.Sleep(50 * time.Millisecond) + defer cancelFunc() + + go sendChatMutation(t, chatServer.URL) + + require.Eventually(t, func() bool { + return client.hasMoreMessagesThan(0) + }, 1*time.Second, 10*time.Millisecond) + + expectedMessage := Message{ + Id: "1", + Type: MessageTypeData, + Payload: []byte(`{"data":{"messageAdded":{"text":"Hello World!","createdBy":"myuser"}}}`), + } + + messagesFromServer := client.readFromServer() + assert.Contains(t, messagesFromServer, expectedMessage) + assert.Equal(t, 1, subscriptionHandler.ActiveSubscriptions()) + }) + + t.Run("id collisions should not be allowed", func(t *testing.T) { + subscriptionHandler, client, handlerRoutine := setupSubscriptionHandlerTest(t, executorPool) + payload, err := subscriptiontesting.GraphQLRequestForOperation(subscriptiontesting.SubscriptionLiveMessages) + require.NoError(t, err) + client.prepareStartMessage("1", payload).withoutError().and().send() + + ctx, cancelFunc := context.WithCancel(context.Background()) + handlerRoutineFunc := handlerRoutine(ctx) + go handlerRoutineFunc() + + time.Sleep(10 * time.Millisecond) + defer cancelFunc() + + go sendChatMutation(t, chatServer.URL) + + require.Eventually(t, func() bool { + return client.hasMoreMessagesThan(0) + }, 5*time.Second, 10*time.Millisecond) + + assert.Equal(t, 1, subscriptionHandler.ActiveSubscriptions()) + + client.prepareStartMessage("1", payload).withoutError().and().send() + require.Eventually(t, func() bool { + return client.hasMoreMessagesThan(1) + }, 5*time.Second, 10*time.Millisecond) + + messagesFromServer := client.readFromServer() + // There are two messages in this slice. The first one is a data message for the first start message + // The second one is an error message because we tried to create a new subscription with an already existed + // id. + + expectedDataMessage := Message{ + Id: "1", + Type: MessageTypeData, + Payload: []byte(`{"data":{"messageAdded":{"text":"Hello World!","createdBy":"myuser"}}}`), + } + assert.Contains(t, messagesFromServer, expectedDataMessage) + + expectedErrorMessage := Message{ + Id: "1", + Type: MessageTypeError, + Payload: []byte(`[{"message":"subscriber id already exists: 1"}]`), + } + assert.Contains(t, messagesFromServer, expectedErrorMessage) + }) + + t.Run("should fail with validation error for invalid Subscription", func(t *testing.T) { + subscriptionHandler, client, handlerRoutine := setupSubscriptionHandlerTest(t, executorPool) + payload, err := subscriptiontesting.GraphQLRequestForOperation(subscriptiontesting.InvalidSubscriptionLiveMessages) + require.NoError(t, err) + client.prepareStartMessage("1", payload).withoutError().and().send() + + ctx, cancelFunc := context.WithCancel(context.Background()) + handlerRoutineFunc := handlerRoutine(ctx) + go handlerRoutineFunc() + + time.Sleep(10 * time.Millisecond) + cancelFunc() + + go sendChatMutation(t, chatServer.URL) + + require.Eventually(t, func() bool { + return client.hasMoreMessagesThan(0) + }, 1*time.Second, 10*time.Millisecond) + + messagesFromServer := client.readFromServer() + assert.Len(t, messagesFromServer, 1) + assert.Equal(t, "1", messagesFromServer[0].Id) + assert.Equal(t, MessageTypeError, messagesFromServer[0].Type) + assert.Equal(t, `[{"message":"differing fields for objectName 'a' on (potentially) same type","path":["subscription","messageAdded"]}]`, string(messagesFromServer[0].Payload)) + assert.Equal(t, 1, subscriptionHandler.ActiveSubscriptions()) + }) + + t.Run("should stop subscription on stop and send complete message to client", func(t *testing.T) { + subscriptionHandler, client, handlerRoutine := setupSubscriptionHandlerTest(t, executorPool) + client.reconnect().prepareStopMessage("1").withoutError().and().send() + + ctx, cancelFunc := context.WithCancel(context.Background()) + handlerRoutineFunc := handlerRoutine(ctx) + go handlerRoutineFunc() + + waitForCanceledSubscription := func() bool { + for subscriptionHandler.ActiveSubscriptions() > 0 { + } + return true + } + + assert.Eventually(t, waitForCanceledSubscription, 1*time.Second, 5*time.Millisecond) + assert.Equal(t, 0, subscriptionHandler.ActiveSubscriptions()) + + expectedMessage := Message{ + Id: "1", + Type: MessageTypeComplete, + Payload: nil, + } + + messagesFromServer := client.readFromServer() + assert.Contains(t, messagesFromServer, expectedMessage) + + cancelFunc() + }) + + t.Run("should interrupt subscription on start and return error message from hook", func(t *testing.T) { + subscriptionHandler, client, handlerRoutine := setupSubscriptionHandlerTest(t, executorPool) + + payload, err := subscriptiontesting.GraphQLRequestForOperation(subscriptiontesting.SubscriptionLiveMessages) + require.NoError(t, err) + + errMsg := "sub_interrupted" + hookHolder.hook = func(ctx context.Context, operation *graphql.Request) error { + return errors.New(errMsg) + } + + client.prepareStartMessage("1", payload).withoutError().and().send() + + ctx, cancelFunc := context.WithCancel(context.Background()) + handlerRoutineFunc := handlerRoutine(ctx) + go handlerRoutineFunc() + + time.Sleep(10 * time.Millisecond) + cancelFunc() + + go sendChatMutation(t, chatServer.URL) + + require.Eventually(t, func() bool { + return client.hasMoreMessagesThan(0) + }, 1*time.Second, 10*time.Millisecond) + + jsonErrMessage, err := json.Marshal(graphql.RequestErrors{ + {Message: errMsg}, + }) + require.NoError(t, err) + expectedErrMessage := Message{ + Id: "1", + Type: MessageTypeError, + Payload: jsonErrMessage, + } + + messagesFromServer := client.readFromServer() + assert.Contains(t, messagesFromServer, expectedErrMessage) + assert.Equal(t, 0, subscriptionHandler.ActiveSubscriptions()) + assert.True(t, hookHolder.called) + }) + }) + + t.Run("connection_terminate", func(t *testing.T) { + executorPool, _ := setupEngineV2(t, ctx, chatServer.URL) + _, client, handlerRoutine := setupSubscriptionHandlerTest(t, executorPool) + + t.Run("should successfully disconnect from client", func(t *testing.T) { + client.prepareConnectionTerminateMessage().withoutError().and().send() + require.True(t, client.connected) + + ctx, cancelFunc := context.WithCancel(context.Background()) + + cancelFunc() + require.Eventually(t, handlerRoutine(ctx), 1*time.Second, 5*time.Millisecond) + + assert.False(t, client.connected) + }) + }) + + t.Run("client is disconnected", func(t *testing.T) { + executorPool, _ := setupEngineV2(t, ctx, chatServer.URL) + _, client, handlerRoutine := setupSubscriptionHandlerTest(t, executorPool) + + t.Run("server should not read from client and stop handler", func(t *testing.T) { + err := client.Disconnect() + require.NoError(t, err) + require.False(t, client.connected) + + client.prepareConnectionInitMessage().withoutError() + ctx, cancelFunc := context.WithCancel(context.Background()) + + cancelFunc() + require.Eventually(t, handlerRoutine(ctx), 1*time.Second, 5*time.Millisecond) + + assert.False(t, client.serverHasRead) + }) + }) + }) + +} + +func setupEngineV2(t *testing.T, ctx context.Context, chatServerURL string) (*ExecutorV2Pool, *websocketHook) { + chatSchemaBytes, err := subscriptiontesting.LoadSchemaFromExamplesDirectoryWithinPkg() + require.NoError(t, err) + + chatSchema, err := graphql.NewSchemaFromReader(bytes.NewBuffer(chatSchemaBytes)) + require.NoError(t, err) + + engineConf := graphql.NewEngineV2Configuration(chatSchema) + engineConf.SetDataSources([]plan.DataSourceConfiguration{ + { + RootNodes: []plan.TypeField{ + {TypeName: "Mutation", FieldNames: []string{"post"}}, + {TypeName: "Subscription", FieldNames: []string{"messageAdded"}}, + }, + ChildNodes: []plan.TypeField{ + {TypeName: "Message", FieldNames: []string{"text", "createdBy"}}, + }, + Factory: &graphql_datasource.Factory{ + HTTPClient: httpclient.DefaultNetHttpClient, + }, + Custom: graphql_datasource.ConfigJson(graphql_datasource.Configuration{ + Fetch: graphql_datasource.FetchConfiguration{ + URL: chatServerURL, + Method: http.MethodPost, + Header: nil, + }, + Subscription: graphql_datasource.SubscriptionConfiguration{ + URL: chatServerURL, + }, + }), + }, + }) + engineConf.SetFieldConfigurations([]plan.FieldConfiguration{ + { + TypeName: "Mutation", + FieldName: "post", + Arguments: []plan.ArgumentConfiguration{ + { + Name: "roomName", + SourceType: plan.FieldArgumentSource, + }, + { + Name: "username", + SourceType: plan.FieldArgumentSource, + }, + { + Name: "text", + SourceType: plan.FieldArgumentSource, + }, + }, + }, + { + TypeName: "Subscription", + FieldName: "messageAdded", + Arguments: []plan.ArgumentConfiguration{ + { + Name: "roomName", + SourceType: plan.FieldArgumentSource, + }, + }, + }, + }) + + hookHolder := &websocketHook{ + reqCtx: context.Background(), + } + engineConf.SetWebsocketBeforeStartHook(hookHolder) + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, "http://localhost:8080", nil) + require.NoError(t, err) + + req.Header.Set("X-Other-Key", "x-other-value") + + initCtx := NewInitialHttpRequestContext(req) + + engine, err := graphql.NewExecutionEngineV2(initCtx, abstractlogger.NoopLogger, engineConf) + require.NoError(t, err) + + executorPool := NewExecutorV2Pool(engine, hookHolder.reqCtx) + + return executorPool, hookHolder +} + +func setupSubscriptionHandlerTest(t *testing.T, executorPool ExecutorPool) (subscriptionHandler *Handler, client *mockClient, routine handlerRoutine) { + return setupSubscriptionHandlerWithInitFuncTest(t, executorPool, nil) +} + +func setupSubscriptionHandlerWithInitFuncTest( + t *testing.T, + executorPool ExecutorPool, + initFunc WebsocketInitFunc, +) (subscriptionHandler *Handler, client *mockClient, routine handlerRoutine) { + client = newMockClient() + + var err error + subscriptionHandler, err = NewHandlerWithInitFunc(abstractlogger.NoopLogger, client, executorPool, initFunc) + require.NoError(t, err) + + routine = func(ctx context.Context) func() bool { + return func() bool { + subscriptionHandler.Handle(ctx) + return true + } + } + + return subscriptionHandler, client, routine +} + +func jsonizePayload(t *testing.T, payload interface{}) json.RawMessage { + jsonBytes, err := json.Marshal(payload) + require.NoError(t, err) + + return jsonBytes +} + +func sendChatMutation(t *testing.T, url string) { + reqBody, err := subscriptiontesting.GraphQLRequestForOperation(subscriptiontesting.MutationSendMessage) + require.NoError(t, err) + + req, err := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(reqBody)) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/json") + + httpClient := http.Client{} + resp, err := httpClient.Do(req) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) +} diff --git a/v2/pkg/subscription/time_out.go b/v2/pkg/subscription/time_out.go new file mode 100644 index 000000000..329065910 --- /dev/null +++ b/v2/pkg/subscription/time_out.go @@ -0,0 +1,38 @@ +package subscription + +import ( + "context" + "time" + + "github.com/jensneuse/abstractlogger" +) + +// TimeOutParams is a struct to configure a TimeOutChecker. +type TimeOutParams struct { + Name string + Logger abstractlogger.Logger + TimeOutContext context.Context + TimeOutAction func() + TimeOutDuration time.Duration +} + +// TimeOutChecker is a function that can be used in a go routine to perform a time-out action +// after a specific duration or prevent the time-out action by canceling the time-out context before. +// Use TimeOutParams for configuration. +func TimeOutChecker(params TimeOutParams) { + timer := time.NewTimer(params.TimeOutDuration) + defer timer.Stop() + + for { + select { + case <-params.TimeOutContext.Done(): + return + case <-timer.C: + params.Logger.Error("time out happened", + abstractlogger.String("name", params.Name), + ) + params.TimeOutAction() + return + } + } +} diff --git a/v2/pkg/subscription/time_out_test.go b/v2/pkg/subscription/time_out_test.go new file mode 100644 index 000000000..72968fb6c --- /dev/null +++ b/v2/pkg/subscription/time_out_test.go @@ -0,0 +1,59 @@ +package subscription + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/jensneuse/abstractlogger" + "github.com/stretchr/testify/assert" +) + +func TestTimeOutChecker(t *testing.T) { + t.Run("should stop timer if context is done before", func(t *testing.T) { + timeOutActionExecuted := false + timeOutAction := func() { + timeOutActionExecuted = true + } + + timeOutCtx, timeOutCancel := context.WithCancel(context.Background()) + params := TimeOutParams{ + Name: "", + Logger: abstractlogger.Noop{}, + TimeOutContext: timeOutCtx, + TimeOutAction: timeOutAction, + TimeOutDuration: 100 * time.Millisecond, + } + go TimeOutChecker(params) + time.Sleep(5 * time.Millisecond) + timeOutCancel() + <-timeOutCtx.Done() + assert.False(t, timeOutActionExecuted) + }) + + t.Run("should stop process if timer runs out", func(t *testing.T) { + wg := &sync.WaitGroup{} + wg.Add(1) + + timeOutActionExecuted := false + timeOutAction := func() { + timeOutActionExecuted = true + wg.Done() + } + + timeOutCtx, timeOutCancel := context.WithCancel(context.Background()) + defer timeOutCancel() + + params := TimeOutParams{ + Name: "", + Logger: abstractlogger.Noop{}, + TimeOutContext: timeOutCtx, + TimeOutAction: timeOutAction, + TimeOutDuration: 10 * time.Millisecond, + } + go TimeOutChecker(params) + wg.Wait() + assert.True(t, timeOutActionExecuted) + }) +} diff --git a/v2/pkg/subscription/transport_client.go b/v2/pkg/subscription/transport_client.go new file mode 100644 index 000000000..c5ce7b2ab --- /dev/null +++ b/v2/pkg/subscription/transport_client.go @@ -0,0 +1,27 @@ +package subscription + +import ( + "errors" +) + +//go:generate mockgen -destination=transport_client_mock_test.go -package=subscription . TransportClient + +// ErrTransportClientClosedConnection is an error to indicate that the transport client is using closed connection. +var ErrTransportClientClosedConnection = errors.New("transport client has a closed connection") + +// TransportClient provides an interface that can be implemented by any possible subscription client like websockets, mqtt, etc. +// It operates with raw byte slices. +type TransportClient interface { + // ReadBytesFromClient will invoke a read operation from the client connection and return a byte slice. + // This function should return ErrTransportClientClosedConnection when reading on a closed connection. + ReadBytesFromClient() ([]byte, error) + // WriteBytesToClient will invoke a write operation to the client connection using a byte slice. + // This function should return ErrTransportClientClosedConnection when writing on a closed connection. + WriteBytesToClient([]byte) error + // IsConnected will indicate if a connection is still established. + IsConnected() bool + // Disconnect will close the connection between server and client. + Disconnect() error + // DisconnectWithReason will close the connection but is also able to process a reason for closure. + DisconnectWithReason(reason interface{}) error +} diff --git a/v2/pkg/subscription/transport_client_mock_test.go b/v2/pkg/subscription/transport_client_mock_test.go new file mode 100644 index 000000000..cf1b3850f --- /dev/null +++ b/v2/pkg/subscription/transport_client_mock_test.go @@ -0,0 +1,105 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/wundergraph/graphql-go-tools/pkg/subscription (interfaces: TransportClient) + +// Package subscription is a generated GoMock package. +package subscription + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockTransportClient is a mock of TransportClient interface. +type MockTransportClient struct { + ctrl *gomock.Controller + recorder *MockTransportClientMockRecorder +} + +// MockTransportClientMockRecorder is the mock recorder for MockTransportClient. +type MockTransportClientMockRecorder struct { + mock *MockTransportClient +} + +// NewMockTransportClient creates a new mock instance. +func NewMockTransportClient(ctrl *gomock.Controller) *MockTransportClient { + mock := &MockTransportClient{ctrl: ctrl} + mock.recorder = &MockTransportClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockTransportClient) EXPECT() *MockTransportClientMockRecorder { + return m.recorder +} + +// Disconnect mocks base method. +func (m *MockTransportClient) Disconnect() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Disconnect") + ret0, _ := ret[0].(error) + return ret0 +} + +// Disconnect indicates an expected call of Disconnect. +func (mr *MockTransportClientMockRecorder) Disconnect() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Disconnect", reflect.TypeOf((*MockTransportClient)(nil).Disconnect)) +} + +// DisconnectWithReason mocks base method. +func (m *MockTransportClient) DisconnectWithReason(arg0 interface{}) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DisconnectWithReason", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// DisconnectWithReason indicates an expected call of DisconnectWithReason. +func (mr *MockTransportClientMockRecorder) DisconnectWithReason(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisconnectWithReason", reflect.TypeOf((*MockTransportClient)(nil).DisconnectWithReason), arg0) +} + +// IsConnected mocks base method. +func (m *MockTransportClient) IsConnected() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsConnected") + ret0, _ := ret[0].(bool) + return ret0 +} + +// IsConnected indicates an expected call of IsConnected. +func (mr *MockTransportClientMockRecorder) IsConnected() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsConnected", reflect.TypeOf((*MockTransportClient)(nil).IsConnected)) +} + +// ReadBytesFromClient mocks base method. +func (m *MockTransportClient) ReadBytesFromClient() ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadBytesFromClient") + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReadBytesFromClient indicates an expected call of ReadBytesFromClient. +func (mr *MockTransportClientMockRecorder) ReadBytesFromClient() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadBytesFromClient", reflect.TypeOf((*MockTransportClient)(nil).ReadBytesFromClient)) +} + +// WriteBytesToClient mocks base method. +func (m *MockTransportClient) WriteBytesToClient(arg0 []byte) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "WriteBytesToClient", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// WriteBytesToClient indicates an expected call of WriteBytesToClient. +func (mr *MockTransportClientMockRecorder) WriteBytesToClient(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WriteBytesToClient", reflect.TypeOf((*MockTransportClient)(nil).WriteBytesToClient), arg0) +} diff --git a/v2/pkg/subscription/websocket/client.go b/v2/pkg/subscription/websocket/client.go new file mode 100644 index 000000000..003eccdab --- /dev/null +++ b/v2/pkg/subscription/websocket/client.go @@ -0,0 +1,187 @@ +package websocket + +import ( + "errors" + "io" + "net" + "sync" + + "github.com/gobwas/ws" + "github.com/gobwas/ws/wsutil" + "github.com/jensneuse/abstractlogger" + + "github.com/wundergraph/graphql-go-tools/pkg/subscription" +) + +// CloseReason is type that is used to provide a close reason to Client.DisconnectWithReason. +type CloseReason ws.Frame + +// CompiledCloseReason is a pre-compiled close reason to be provided to Client.DisconnectWithReason. +type CompiledCloseReason []byte + +var ( + CompiledCloseReasonNormal CompiledCloseReason = ws.MustCompileFrame( + ws.NewCloseFrame(ws.NewCloseFrameBody( + ws.StatusNormalClosure, "Normal Closure", + )), + ) + CompiledCloseReasonInternalServerError CompiledCloseReason = ws.MustCompileFrame( + ws.NewCloseFrame(ws.NewCloseFrameBody( + ws.StatusInternalServerError, "Internal Server Error", + )), + ) +) + +// NewCloseReason is used to compose a close frame with code and reason message. +func NewCloseReason(code uint16, reason string) CloseReason { + wsCloseFrame := ws.NewCloseFrame(ws.NewCloseFrameBody( + ws.StatusCode(code), reason, + )) + return CloseReason(wsCloseFrame) +} + +// Client is an actual implementation of the subscription client interface. +type Client struct { + logger abstractlogger.Logger + // clientConn holds the actual connection to the client. + clientConn net.Conn + // isClosedConnection indicates if the websocket connection is closed. + isClosedConnection bool + mu *sync.RWMutex +} + +// NewClient will create a new websocket subscription client. +func NewClient(logger abstractlogger.Logger, clientConn net.Conn) *Client { + return &Client{ + logger: logger, + clientConn: clientConn, + mu: &sync.RWMutex{}, + } +} + +// ReadBytesFromClient will read a subscription message from the websocket client. +func (c *Client) ReadBytesFromClient() ([]byte, error) { + if !c.IsConnected() { + return nil, subscription.ErrTransportClientClosedConnection + } + + data, opCode, err := wsutil.ReadClientData(c.clientConn) + if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, io.ErrUnexpectedEOF) { + c.changeConnectionStateToClosed() + return nil, subscription.ErrTransportClientClosedConnection + } else if err != nil { + if c.isClosedConnectionError(err) { + return nil, subscription.ErrTransportClientClosedConnection + } + + c.logger.Error("websocket.Client.ReadBytesFromClient: after reading from client", + abstractlogger.Error(err), + abstractlogger.ByteString("data", data), + abstractlogger.Any("opCode", opCode), + ) + + c.isClosedConnectionError(err) + + return nil, err + } + + return data, nil +} + +// WriteBytesToClient will write a subscription message to the websocket client. +func (c *Client) WriteBytesToClient(message []byte) error { + if !c.IsConnected() { + return subscription.ErrTransportClientClosedConnection + } + + err := wsutil.WriteServerMessage(c.clientConn, ws.OpText, message) + if errors.Is(err, io.ErrClosedPipe) { + c.changeConnectionStateToClosed() + return subscription.ErrTransportClientClosedConnection + } else if err != nil { + c.logger.Error("websocket.Client.WriteBytesToClient: after writing to client", + abstractlogger.Error(err), + abstractlogger.ByteString("message", message), + ) + + return err + } + + return nil +} + +// IsConnected will indicate if the websocket connection is still established. +func (c *Client) IsConnected() bool { + c.mu.RLock() + defer c.mu.RUnlock() + return !c.isClosedConnection +} + +// Disconnect will close the websocket connection. +func (c *Client) Disconnect() error { + c.logger.Debug("websocket.Client.Disconnect: before disconnect", + abstractlogger.String("message", "disconnecting client"), + ) + c.changeConnectionStateToClosed() + return c.clientConn.Close() +} + +// DisconnectWithReason will close the websocket and provide the close code and reason. +// It can only consume CloseReason or CompiledCloseReason. +func (c *Client) DisconnectWithReason(reason interface{}) error { + var err error + switch reason := reason.(type) { + case CloseReason: + err = c.writeFrame(ws.Frame(reason)) + case CompiledCloseReason: + err = c.writeCompiledFrame(reason) + default: + c.logger.Error("websocket.Client.DisconnectWithReason: on reason/frame parsing", + abstractlogger.String("message", "unknown reason provided"), + ) + frame := NewCloseReason(4400, "unknown reason") + err = c.writeFrame(ws.Frame(frame)) + } + + c.logger.Debug("websocket.Client.DisconnectWithReason: before sending close frame", + abstractlogger.String("message", "disconnecting client"), + ) + + if err != nil { + c.logger.Error("websocket.Client.DisconnectWithReason: after writing close reason", + abstractlogger.Error(err), + ) + return err + } + + return c.Disconnect() +} + +func (c *Client) writeFrame(frame ws.Frame) error { + return ws.WriteFrame(c.clientConn, frame) +} + +func (c *Client) writeCompiledFrame(compiledFrame []byte) error { + _, err := c.clientConn.Write(compiledFrame) + return err +} + +// isClosedConnectionError will indicate if the given error is a connection closed error. +func (c *Client) isClosedConnectionError(err error) bool { + c.mu.Lock() + defer c.mu.Unlock() + var closedErr wsutil.ClosedError + if errors.As(err, &closedErr) { + c.isClosedConnection = true + } + return c.isClosedConnection +} + +func (c *Client) changeConnectionStateToClosed() { + c.mu.Lock() + defer c.mu.Unlock() + c.isClosedConnection = true +} + +// Interface Guard +var _ subscription.TransportClient = (*Client)(nil) diff --git a/v2/pkg/subscription/websocket/client_test.go b/v2/pkg/subscription/websocket/client_test.go new file mode 100644 index 000000000..85be6940e --- /dev/null +++ b/v2/pkg/subscription/websocket/client_test.go @@ -0,0 +1,461 @@ +package websocket + +import ( + "errors" + "fmt" + "io" + "net" + "sync" + "testing" + "time" + + "github.com/gobwas/ws" + "github.com/gobwas/ws/wsutil" + "github.com/jensneuse/abstractlogger" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/wundergraph/graphql-go-tools/pkg/subscription" +) + +type testServerWebsocketResponse struct { + data []byte + opCode ws.OpCode + statusCode ws.StatusCode + closeReason string + err error +} + +func TestClient_WriteToClient(t *testing.T) { + t.Run("should write successfully to client", func(t *testing.T) { + connToServer, connToClient := net.Pipe() + websocketClient := NewClient(abstractlogger.NoopLogger, connToClient) + messageToClient := []byte(`{ + "id": "1", + "type": "data", + "payload": {"data":null} + }`) + + go func() { + err := websocketClient.WriteBytesToClient(messageToClient) + assert.NoError(t, err) + }() + + data, opCode, err := wsutil.ReadServerData(connToServer) + require.NoError(t, err) + require.Equal(t, ws.OpText, opCode) + + time.Sleep(10 * time.Millisecond) + assert.Equal(t, messageToClient, data) + }) + + t.Run("should not write to client when connection is closed", func(t *testing.T) { + t.Run("when not wrapped", func(t *testing.T) { + t.Run("io: read/write on closed pipe", func(t *testing.T) { + connToServer, connToClient := net.Pipe() + websocketClient := NewClient(abstractlogger.NoopLogger, connToClient) + err := connToServer.Close() + require.NoError(t, err) + + err = websocketClient.WriteBytesToClient([]byte("")) + assert.Equal(t, subscription.ErrTransportClientClosedConnection, err) + assert.True(t, websocketClient.isClosedConnection) + }) + }) + + t.Run("when wrapped", func(t *testing.T) { + t.Run("io: read/write on closed pipe", func(t *testing.T) { + connToClient := FakeConn{} + wrappedErr := fmt.Errorf("outside wrapper: %w", + fmt.Errorf("inner wrapper: %w", + io.ErrClosedPipe, + ), + ) + connToClient.setWriteReturns(0, wrappedErr) + websocketClient := NewClient(abstractlogger.NoopLogger, &connToClient) + + err := websocketClient.WriteBytesToClient([]byte("message")) + assert.Equal(t, subscription.ErrTransportClientClosedConnection, err) + assert.True(t, websocketClient.isClosedConnection) + }) + }) + }) +} + +func TestClient_ReadFromClient(t *testing.T) { + t.Run("should successfully read from client", func(t *testing.T) { + connToServer, connToClient := net.Pipe() + websocketClient := NewClient(abstractlogger.NoopLogger, connToClient) + + messageToServer := []byte(`{ + "id": "1", + "type": "data", + "payload": {"data":null} + }`) + + go func() { + err := wsutil.WriteClientText(connToServer, messageToServer) + require.NoError(t, err) + }() + + time.Sleep(10 * time.Millisecond) + + messageFromClient, err := websocketClient.ReadBytesFromClient() + assert.NoError(t, err) + assert.Equal(t, messageToServer, messageFromClient) + }) + t.Run("should detect a closed connection", func(t *testing.T) { + t.Run("before read", func(t *testing.T) { + _, connToClient := net.Pipe() + websocketClient := NewClient(abstractlogger.NoopLogger, connToClient) + defer connToClient.Close() + websocketClient.isClosedConnection = true + + assert.Eventually(t, func() bool { + _, err := websocketClient.ReadBytesFromClient() + return assert.Equal(t, subscription.ErrTransportClientClosedConnection, err) + }, 1*time.Second, 2*time.Millisecond) + }) + t.Run("when not wrapped", func(t *testing.T) { + t.Run("io.EOF", func(t *testing.T) { + connToServer, connToClient := net.Pipe() + websocketClient := NewClient(abstractlogger.NoopLogger, connToClient) + err := connToServer.Close() + require.NoError(t, err) + + _, err = websocketClient.ReadBytesFromClient() + assert.Equal(t, subscription.ErrTransportClientClosedConnection, err) + assert.True(t, websocketClient.isClosedConnection) + }) + t.Run("io: read/write on closed pipe", func(t *testing.T) { + connToClient := &FakeConn{} + connToClient.setReadReturns(0, io.ErrClosedPipe) + websocketClient := NewClient(abstractlogger.NoopLogger, connToClient) + + _, err := websocketClient.ReadBytesFromClient() + assert.Equal(t, subscription.ErrTransportClientClosedConnection, err) + assert.True(t, websocketClient.isClosedConnection) + }) + t.Run("unexpected EOF", func(t *testing.T) { + connToClient := &FakeConn{} + connToClient.setReadReturns(0, io.ErrUnexpectedEOF) + websocketClient := NewClient(abstractlogger.NoopLogger, connToClient) + + _, err := websocketClient.ReadBytesFromClient() + assert.Equal(t, subscription.ErrTransportClientClosedConnection, err) + assert.True(t, websocketClient.isClosedConnection) + }) + }) + + t.Run("when wrapped", func(t *testing.T) { + t.Run("io.EOF", func(t *testing.T) { + connToClient := &FakeConn{} + wrappedErr := fmt.Errorf("outside wrapper: %w", + fmt.Errorf("inner wrapper: %w", + io.EOF, + ), + ) + connToClient.setReadReturns(0, wrappedErr) + websocketClient := NewClient(abstractlogger.NoopLogger, connToClient) + + _, err := websocketClient.ReadBytesFromClient() + assert.Equal(t, subscription.ErrTransportClientClosedConnection, err) + assert.True(t, websocketClient.isClosedConnection) + }) + t.Run("io: read/write on closed pipe", func(t *testing.T) { + connToClient := &FakeConn{} + wrappedErr := fmt.Errorf("outside wrapper: %w", + fmt.Errorf("inner wrapper: %w", + io.ErrClosedPipe, + ), + ) + connToClient.setReadReturns(0, wrappedErr) + websocketClient := NewClient(abstractlogger.NoopLogger, connToClient) + + _, err := websocketClient.ReadBytesFromClient() + assert.Equal(t, subscription.ErrTransportClientClosedConnection, err) + assert.True(t, websocketClient.isClosedConnection) + }) + t.Run("unexpected EOF", func(t *testing.T) { + connToClient := &FakeConn{} + wrappedErr := fmt.Errorf("outside wrapper: %w", + fmt.Errorf("inner wrapper: %w", + io.ErrUnexpectedEOF, + ), + ) + connToClient.setReadReturns(0, wrappedErr) + websocketClient := NewClient(abstractlogger.NoopLogger, connToClient) + + _, err := websocketClient.ReadBytesFromClient() + assert.Equal(t, subscription.ErrTransportClientClosedConnection, err) + assert.True(t, websocketClient.isClosedConnection) + }) + }) + + }) +} + +func TestClient_IsConnected(t *testing.T) { + _, connToClient := net.Pipe() + websocketClient := NewClient(abstractlogger.NoopLogger, connToClient) + + t.Run("should return true when a connection is established", func(t *testing.T) { + isConnected := websocketClient.IsConnected() + assert.True(t, isConnected) + }) + + t.Run("should return false when a connection is closed", func(t *testing.T) { + err := connToClient.Close() + require.NoError(t, err) + + websocketClient.isClosedConnection = true + + isConnected := websocketClient.IsConnected() + assert.False(t, isConnected) + }) +} + +func TestClient_Disconnect(t *testing.T) { + _, connToClient := net.Pipe() + websocketClient := NewClient(abstractlogger.NoopLogger, connToClient) + + t.Run("should disconnect and indicate a closed connection", func(t *testing.T) { + err := websocketClient.Disconnect() + assert.NoError(t, err) + assert.Equal(t, true, websocketClient.isClosedConnection) + }) +} + +func TestClient_DisconnectWithReason(t *testing.T) { + t.Run("disconnect with invalid reason", func(t *testing.T) { + connToServer, connToClient := net.Pipe() + websocketClient := NewClient(abstractlogger.NoopLogger, connToClient) + serverResponseChan := make(chan testServerWebsocketResponse) + + go readServerResponse(serverResponseChan, connToServer) + + go func() { + err := websocketClient.DisconnectWithReason( + "invalid reason", + ) + assert.NoError(t, err) + }() + + assert.Eventually(t, func() bool { + actualServerResult := <-serverResponseChan + assert.NoError(t, actualServerResult.err) + assert.Equal(t, ws.OpClose, actualServerResult.opCode) + assert.Equal(t, ws.StatusCode(4400), actualServerResult.statusCode) + assert.Equal(t, "unknown reason", actualServerResult.closeReason) + assert.Equal(t, false, websocketClient.IsConnected()) + return true + }, 1*time.Second, 2*time.Millisecond) + }) + + t.Run("disconnect with reason", func(t *testing.T) { + connToServer, connToClient := net.Pipe() + websocketClient := NewClient(abstractlogger.NoopLogger, connToClient) + serverResponseChan := make(chan testServerWebsocketResponse) + + go readServerResponse(serverResponseChan, connToServer) + + go func() { + err := websocketClient.DisconnectWithReason( + NewCloseReason(4400, "error occurred"), + ) + assert.NoError(t, err) + }() + + assert.Eventually(t, func() bool { + actualServerResult := <-serverResponseChan + assert.NoError(t, actualServerResult.err) + assert.Equal(t, ws.OpClose, actualServerResult.opCode) + assert.Equal(t, ws.StatusCode(4400), actualServerResult.statusCode) + assert.Equal(t, "error occurred", actualServerResult.closeReason) + assert.Equal(t, false, websocketClient.IsConnected()) + return true + }, 1*time.Second, 2*time.Millisecond) + }) + + t.Run("disconnect with compiled reason", func(t *testing.T) { + connToServer, connToClient := net.Pipe() + websocketClient := NewClient(abstractlogger.NoopLogger, connToClient) + serverResponseChan := make(chan testServerWebsocketResponse) + + go readServerResponse(serverResponseChan, connToServer) + + go func() { + err := websocketClient.DisconnectWithReason( + CompiledCloseReasonNormal, + ) + assert.NoError(t, err) + }() + + assert.Eventually(t, func() bool { + actualServerResult := <-serverResponseChan + assert.NoError(t, actualServerResult.err) + assert.Equal(t, ws.OpClose, actualServerResult.opCode) + assert.Equal(t, ws.StatusCode(1000), actualServerResult.statusCode) + assert.Equal(t, "Normal Closure", actualServerResult.closeReason) + assert.Equal(t, false, websocketClient.IsConnected()) + return true + }, 1*time.Second, 2*time.Millisecond) + }) +} + +func TestClient_isClosedConnectionError(t *testing.T) { + _, connToClient := net.Pipe() + + t.Run("should not close connection when it is not a closed connection error", func(t *testing.T) { + websocketClient := NewClient(abstractlogger.NoopLogger, connToClient) + require.False(t, websocketClient.isClosedConnection) + + isClosedConnectionError := websocketClient.isClosedConnectionError(errors.New("no closed connection err")) + assert.False(t, isClosedConnectionError) + }) + + t.Run("should close connection when it is a closed connection error", func(t *testing.T) { + websocketClient := NewClient(abstractlogger.NoopLogger, connToClient) + require.False(t, websocketClient.isClosedConnection) + + isClosedConnectionError := websocketClient.isClosedConnectionError(wsutil.ClosedError{}) + assert.True(t, isClosedConnectionError) + websocketClient.isClosedConnection = false + + require.False(t, websocketClient.isClosedConnection) + isClosedConnectionError = websocketClient.isClosedConnectionError(wsutil.ClosedError{ + Code: ws.StatusNormalClosure, + Reason: "Normal Closure", + }) + assert.True(t, isClosedConnectionError) + }) +} + +type TestClient struct { + connectionMutex *sync.RWMutex + messageFromClient chan []byte + messageToClient chan []byte + isConnected bool + shouldFail bool +} + +func NewTestClient(shouldFail bool) *TestClient { + return &TestClient{ + connectionMutex: &sync.RWMutex{}, + messageFromClient: make(chan []byte, 1), + messageToClient: make(chan []byte, 1), + isConnected: true, + shouldFail: shouldFail, + } +} + +func (t *TestClient) ReadBytesFromClient() ([]byte, error) { + if t.shouldFail { + return nil, errors.New("shouldFail is true") + } + return <-t.messageFromClient, nil +} + +func (t *TestClient) WriteBytesToClient(message []byte) error { + if t.shouldFail { + return errors.New("shouldFail is true") + } + t.messageToClient <- message + return nil +} + +func (t *TestClient) IsConnected() bool { + t.connectionMutex.RLock() + defer t.connectionMutex.RUnlock() + return t.isConnected +} + +func (t *TestClient) Disconnect() error { + t.connectionMutex.Lock() + defer t.connectionMutex.Unlock() + t.isConnected = false + return nil +} + +func (t *TestClient) DisconnectWithReason(reason interface{}) error { + t.connectionMutex.Lock() + defer t.connectionMutex.Unlock() + t.isConnected = false + return nil +} + +func (t *TestClient) readMessageToClient() []byte { + return <-t.messageToClient +} + +func (t *TestClient) writeMessageFromClient(message []byte) { + t.messageFromClient <- message +} + +type FakeConn struct { + readReturnN int + readReturnErr error + writeReturnN int + writeReturnErr error +} + +func (f *FakeConn) setReadReturns(n int, err error) { + f.readReturnN = n + f.readReturnErr = err +} + +func (f *FakeConn) Read(b []byte) (n int, err error) { + return f.readReturnN, f.readReturnErr +} + +func (f *FakeConn) setWriteReturns(n int, err error) { + f.writeReturnN = n + f.writeReturnErr = err +} + +func (f *FakeConn) Write(b []byte) (n int, err error) { + return f.writeReturnN, f.writeReturnErr +} + +func (f *FakeConn) Close() error { + panic("implement me") +} + +func (f *FakeConn) LocalAddr() net.Addr { + panic("implement me") +} + +func (f *FakeConn) RemoteAddr() net.Addr { + panic("implement me") +} + +func (f *FakeConn) SetDeadline(t time.Time) error { + panic("implement me") +} + +func (f *FakeConn) SetReadDeadline(t time.Time) error { + panic("implement me") +} + +func (f *FakeConn) SetWriteDeadline(t time.Time) error { + panic("implement me") +} + +func readServerResponse(responseChan chan testServerWebsocketResponse, connToServer net.Conn) { + var statusCode ws.StatusCode + var closeReason string + frame, err := ws.ReadFrame(connToServer) + if err == nil { + statusCode, closeReason = ws.ParseCloseFrameData(frame.Payload) + } + + response := testServerWebsocketResponse{ + data: frame.Payload, + opCode: frame.Header.OpCode, + statusCode: statusCode, + closeReason: closeReason, + err: err, + } + + responseChan <- response +} diff --git a/v2/pkg/subscription/websocket/engine_mock_test.go b/v2/pkg/subscription/websocket/engine_mock_test.go new file mode 100644 index 000000000..da14ffc45 --- /dev/null +++ b/v2/pkg/subscription/websocket/engine_mock_test.go @@ -0,0 +1,78 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/wundergraph/graphql-go-tools/pkg/subscription (interfaces: Engine) + +// Package websocket is a generated GoMock package. +package websocket + +import ( + context "context" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + subscription "github.com/wundergraph/graphql-go-tools/pkg/subscription" +) + +// MockEngine is a mock of Engine interface. +type MockEngine struct { + ctrl *gomock.Controller + recorder *MockEngineMockRecorder +} + +// MockEngineMockRecorder is the mock recorder for MockEngine. +type MockEngineMockRecorder struct { + mock *MockEngine +} + +// NewMockEngine creates a new mock instance. +func NewMockEngine(ctrl *gomock.Controller) *MockEngine { + mock := &MockEngine{ctrl: ctrl} + mock.recorder = &MockEngineMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockEngine) EXPECT() *MockEngineMockRecorder { + return m.recorder +} + +// StartOperation mocks base method. +func (m *MockEngine) StartOperation(arg0 context.Context, arg1 string, arg2 []byte, arg3 subscription.EventHandler) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "StartOperation", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].(error) + return ret0 +} + +// StartOperation indicates an expected call of StartOperation. +func (mr *MockEngineMockRecorder) StartOperation(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartOperation", reflect.TypeOf((*MockEngine)(nil).StartOperation), arg0, arg1, arg2, arg3) +} + +// StopSubscription mocks base method. +func (m *MockEngine) StopSubscription(arg0 string, arg1 subscription.EventHandler) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "StopSubscription", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// StopSubscription indicates an expected call of StopSubscription. +func (mr *MockEngineMockRecorder) StopSubscription(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StopSubscription", reflect.TypeOf((*MockEngine)(nil).StopSubscription), arg0, arg1) +} + +// TerminateAllSubscriptions mocks base method. +func (m *MockEngine) TerminateAllSubscriptions(arg0 subscription.EventHandler) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "TerminateAllSubscriptions", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// TerminateAllSubscriptions indicates an expected call of TerminateAllSubscriptions. +func (mr *MockEngineMockRecorder) TerminateAllSubscriptions(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TerminateAllSubscriptions", reflect.TypeOf((*MockEngine)(nil).TerminateAllSubscriptions), arg0) +} diff --git a/v2/pkg/subscription/websocket/handler.go b/v2/pkg/subscription/websocket/handler.go new file mode 100644 index 000000000..226b8753e --- /dev/null +++ b/v2/pkg/subscription/websocket/handler.go @@ -0,0 +1,228 @@ +package websocket + +import ( + "context" + "net" + "net/http" + "time" + + "github.com/jensneuse/abstractlogger" + + "github.com/wundergraph/graphql-go-tools/pkg/subscription" +) + +const ( + DefaultConnectionInitTimeOut = "15s" + + HeaderSecWebSocketProtocol = "Sec-WebSocket-Protocol" +) + +// Protocol defines the protocol names as type. +type Protocol string + +const ( + ProtocolUndefined Protocol = "" + ProtocolGraphQLWS Protocol = "graphql-ws" + ProtocolGraphQLTransportWS Protocol = "graphql-transport-ws" +) + +var DefaultProtocol = ProtocolGraphQLTransportWS + +// HandleOptions can be used to pass options to the websocket handler. +type HandleOptions struct { + Logger abstractlogger.Logger + Protocol Protocol + WebSocketInitFunc InitFunc + CustomClient subscription.TransportClient + CustomKeepAliveInterval time.Duration + CustomSubscriptionUpdateInterval time.Duration + CustomConnectionInitTimeOut time.Duration + CustomReadErrorTimeOut time.Duration + CustomSubscriptionEngine subscription.Engine +} + +// HandleOptionFunc can be used to define option functions. +type HandleOptionFunc func(opts *HandleOptions) + +// WithLogger is a function that sets a logger for the websocket handler. +func WithLogger(logger abstractlogger.Logger) HandleOptionFunc { + return func(opts *HandleOptions) { + opts.Logger = logger + } +} + +// WithInitFunc is a function that sets the init function for the websocket handler. +func WithInitFunc(initFunc InitFunc) HandleOptionFunc { + return func(opts *HandleOptions) { + opts.WebSocketInitFunc = initFunc + } +} + +// WithCustomClient is a function that set a custom transport client for the websocket handler. +func WithCustomClient(client subscription.TransportClient) HandleOptionFunc { + return func(opts *HandleOptions) { + opts.CustomClient = client + } +} + +// WithCustomKeepAliveInterval is a function that sets a custom keep-alive interval for the websocket handler. +func WithCustomKeepAliveInterval(keepAliveInterval time.Duration) HandleOptionFunc { + return func(opts *HandleOptions) { + opts.CustomKeepAliveInterval = keepAliveInterval + } +} + +// WithCustomSubscriptionUpdateInterval is a function that sets a custom subscription update interval for the +// websocket handler. +func WithCustomSubscriptionUpdateInterval(subscriptionUpdateInterval time.Duration) HandleOptionFunc { + return func(opts *HandleOptions) { + opts.CustomSubscriptionUpdateInterval = subscriptionUpdateInterval + } +} + +// WithCustomConnectionInitTimeOut is a function that sets a custom connection init time out. +func WithCustomConnectionInitTimeOut(connectionInitTimeOut time.Duration) HandleOptionFunc { + return func(opts *HandleOptions) { + opts.CustomConnectionInitTimeOut = connectionInitTimeOut + } +} + +// WithCustomReadErrorTimeOut is a function that sets a custom read error time out for the +// websocket handler. +func WithCustomReadErrorTimeOut(readErrorTimeOut time.Duration) HandleOptionFunc { + return func(opts *HandleOptions) { + opts.CustomReadErrorTimeOut = readErrorTimeOut + } +} + +// WithCustomSubscriptionEngine is a function that sets a custom subscription engine for the websocket handler. +func WithCustomSubscriptionEngine(subscriptionEngine subscription.Engine) HandleOptionFunc { + return func(opts *HandleOptions) { + opts.CustomSubscriptionEngine = subscriptionEngine + } +} + +// WithProtocol is a function that sets the protocol. +func WithProtocol(protocol Protocol) HandleOptionFunc { + return func(opts *HandleOptions) { + opts.Protocol = protocol + } +} + +// WithProtocolFromRequestHeaders is a function that sets the protocol based on the request headers. +// It fallbacks to the DefaultProtocol if the header can't be found, the value is invalid or no request +// was provided. +func WithProtocolFromRequestHeaders(req *http.Request) HandleOptionFunc { + return func(opts *HandleOptions) { + if req == nil { + opts.Protocol = DefaultProtocol + return + } + + protocolHeaderValue := req.Header.Get(HeaderSecWebSocketProtocol) + switch Protocol(protocolHeaderValue) { + case ProtocolGraphQLWS: + opts.Protocol = ProtocolGraphQLWS + case ProtocolGraphQLTransportWS: + opts.Protocol = ProtocolGraphQLTransportWS + default: + opts.Protocol = DefaultProtocol + } + } +} + +// Handle will handle the websocket subscription. It can take optional option functions to customize the handler. +// behavior. By default, it uses the 'graphql-transport-ws' protocol. +func Handle(done chan bool, errChan chan error, conn net.Conn, executorPool subscription.ExecutorPool, options ...HandleOptionFunc) { + definedOptions := HandleOptions{ + Logger: abstractlogger.Noop{}, + Protocol: DefaultProtocol, + } + + for _, optionFunc := range options { + optionFunc(&definedOptions) + } + + HandleWithOptions(done, errChan, conn, executorPool, definedOptions) +} + +// HandleWithOptions will handle the websocket connection. It requires an option struct to define the behavior. +func HandleWithOptions(done chan bool, errChan chan error, conn net.Conn, executorPool subscription.ExecutorPool, options HandleOptions) { + // Use noop logger to prevent nil pointers if none was provided + if options.Logger == nil { + options.Logger = abstractlogger.Noop{} + } + + defer func() { + if err := conn.Close(); err != nil { + options.Logger.Error("websocket.HandleWithOptions: on deferred closing connection", + abstractlogger.String("message", "could not close connection to client"), + abstractlogger.Error(err), + ) + } + }() + + var client subscription.TransportClient + if options.CustomClient != nil { + client = options.CustomClient + } else { + client = NewClient(options.Logger, conn) + } + + protocolHandler, err := createProtocolHandler(options, client) + if err != nil { + options.Logger.Error("websocket.HandleWithOptions: on protocol handler creation", + abstractlogger.String("message", "could not create protocol handler"), + abstractlogger.String("protocol", string(DefaultProtocol)), + abstractlogger.Error(err), + ) + + errChan <- err + return + } + + subscriptionHandler, err := subscription.NewUniversalProtocolHandlerWithOptions(client, protocolHandler, executorPool, subscription.UniversalProtocolHandlerOptions{ + Logger: options.Logger, + CustomSubscriptionUpdateInterval: options.CustomSubscriptionUpdateInterval, + CustomReadErrorTimeOut: options.CustomReadErrorTimeOut, + CustomEngine: options.CustomSubscriptionEngine, + }) + if err != nil { + options.Logger.Error("websocket.HandleWithOptions: on subscription handler creation", + abstractlogger.String("message", "could not create subscription handler"), + abstractlogger.String("protocol", string(DefaultProtocol)), + abstractlogger.Error(err), + ) + + errChan <- err + return + } + + close(done) + subscriptionHandler.Handle(context.Background()) // Blocking +} + +func createProtocolHandler(handleOptions HandleOptions, client subscription.TransportClient) (protocolHandler subscription.Protocol, err error) { + protocol := handleOptions.Protocol + if protocol == ProtocolUndefined { + protocol = DefaultProtocol + } + + switch protocol { + case ProtocolGraphQLWS: + protocolHandler, err = NewProtocolGraphQLWSHandlerWithOptions(client, ProtocolGraphQLWSHandlerOptions{ + Logger: handleOptions.Logger, + WebSocketInitFunc: handleOptions.WebSocketInitFunc, + CustomKeepAliveInterval: handleOptions.CustomKeepAliveInterval, + }) + default: + protocolHandler, err = NewProtocolGraphQLTransportWSHandlerWithOptions(client, ProtocolGraphQLTransportWSHandlerOptions{ + Logger: handleOptions.Logger, + WebSocketInitFunc: handleOptions.WebSocketInitFunc, + CustomKeepAliveInterval: handleOptions.CustomKeepAliveInterval, + CustomInitTimeOutDuration: handleOptions.CustomConnectionInitTimeOut, + }) + } + + return protocolHandler, err +} diff --git a/v2/pkg/subscription/websocket/handler_test.go b/v2/pkg/subscription/websocket/handler_test.go new file mode 100644 index 000000000..cc4fc8a84 --- /dev/null +++ b/v2/pkg/subscription/websocket/handler_test.go @@ -0,0 +1,344 @@ +package websocket + +import ( + "bytes" + "context" + "errors" + "net" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/jensneuse/abstractlogger" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/wundergraph/graphql-go-tools/pkg/engine/datasource/graphql_datasource" + "github.com/wundergraph/graphql-go-tools/pkg/engine/datasource/httpclient" + "github.com/wundergraph/graphql-go-tools/pkg/engine/plan" + "github.com/wundergraph/graphql-go-tools/pkg/graphql" + "github.com/wundergraph/graphql-go-tools/pkg/subscription" + "github.com/wundergraph/graphql-go-tools/pkg/testing/subscriptiontesting" +) + +func TestHandleWithOptions(t *testing.T) { + t.Run("should handle protocol graphql-ws", func(t *testing.T) { + chatServer := httptest.NewServer(subscriptiontesting.ChatGraphQLEndpointHandler()) + defer chatServer.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + executorPoolV2 := setupExecutorPoolV2(t, ctx, chatServer.URL, nil) + serverConn, _ := net.Pipe() + testClient := NewTestClient(false) + + done := make(chan bool) + errChan := make(chan error) + go Handle( + done, + errChan, + serverConn, + executorPoolV2, + WithProtocol(ProtocolGraphQLWS), + WithCustomClient(testClient), + WithCustomSubscriptionUpdateInterval(50*time.Millisecond), + WithCustomKeepAliveInterval(3600*time.Second), // keep_alive should not intervene with our tests, so make it high + ) + + require.Eventually(t, func() bool { + <-done + return true + }, 1*time.Second, 2*time.Millisecond) + + testClient.writeMessageFromClient([]byte(`{"type":"connection_init"}`)) + assert.Eventually(t, func() bool { + expectedMessage := []byte(`{"type":"connection_ack"}`) + actualMessage := testClient.readMessageToClient() + assert.Equal(t, expectedMessage, actualMessage) + return true + }, 1*time.Second, 2*time.Millisecond, "never satisfied on connection_init") + + testClient.writeMessageFromClient([]byte(`{"id":"1","type":"start","payload":{"query":"{ room(name:\"#my_room\") { name } }"}}`)) + assert.Eventually(t, func() bool { + expectedMessage := []byte(`{"id":"1","type":"data","payload":{"data":{"room":{"name":"#my_room"}}}}`) + actualMessage := testClient.readMessageToClient() + assert.Equal(t, expectedMessage, actualMessage) + expectedMessage = []byte(`{"id":"1","type":"complete"}`) + actualMessage = testClient.readMessageToClient() + assert.Equal(t, expectedMessage, actualMessage) + return true + }, 2*time.Second, 2*time.Millisecond, "never satisfied on start non-subscription") + + testClient.writeMessageFromClient([]byte(`{"id":"2","type":"start","payload":{"query":"subscription { messageAdded(roomName:\"#my_room\") { text } }"}}`)) + time.Sleep(15 * time.Millisecond) + testClient.writeMessageFromClient([]byte(`{"id":"3","type":"start","payload":{"query":"mutation { post(text: \"hello\", username: \"me\", roomName: \"#my_room\") { text } }"}}`)) + assert.Eventually(t, func() bool { + expectedMessages := []string{ + `{"id":"3","type":"data","payload":{"data":{"post":{"text":"hello"}}}}`, + `{"id":"3","type":"complete"}`, + `{"id":"2","type":"data","payload":{"data":{"messageAdded":{"text":"hello"}}}}`, + } + actualMessage := testClient.readMessageToClient() + assert.Contains(t, expectedMessages, string(actualMessage)) + actualMessage = testClient.readMessageToClient() + assert.Contains(t, expectedMessages, string(actualMessage)) + actualMessage = testClient.readMessageToClient() + assert.Contains(t, expectedMessages, string(actualMessage)) + return true + }, 2*time.Second, 2*time.Millisecond, "never satisfied on start subscription") + + testClient.writeMessageFromClient([]byte(`{"id":"2","type":"stop"}`)) + assert.Eventually(t, func() bool { + expectedMessage := []byte(`{"id":"2","type":"complete"}`) + actualMessage := testClient.readMessageToClient() + assert.Equal(t, expectedMessage, actualMessage) + return true + }, 2*time.Second, 2*time.Millisecond, "never satisfied on stop subscription") + }) + + t.Run("should handle protocol graphql-transport-ws", func(t *testing.T) { + chatServer := httptest.NewServer(subscriptiontesting.ChatGraphQLEndpointHandler()) + defer chatServer.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + executorPoolV2 := setupExecutorPoolV2(t, ctx, chatServer.URL, nil) + serverConn, _ := net.Pipe() + testClient := NewTestClient(false) + + done := make(chan bool) + errChan := make(chan error) + go Handle( + done, + errChan, + serverConn, + executorPoolV2, + WithProtocol(ProtocolGraphQLTransportWS), + WithCustomClient(testClient), + WithCustomSubscriptionUpdateInterval(50*time.Millisecond), + WithCustomKeepAliveInterval(3600*time.Second), // keep_alive should not intervene with our tests, so make it high + ) + + require.Eventually(t, func() bool { + <-done + return true + }, 1*time.Second, 2*time.Millisecond) + + testClient.writeMessageFromClient([]byte(`{"type":"connection_init"}`)) + assert.Eventually(t, func() bool { + expectedMessage := []byte(`{"type":"connection_ack"}`) + actualMessage := testClient.readMessageToClient() + assert.Equal(t, expectedMessage, actualMessage) + return true + }, 1*time.Second, 2*time.Millisecond, "never satisfied on connection_init") + + testClient.writeMessageFromClient([]byte(`{"id":"1","type":"subscribe","payload":{"query":"{ room(name:\"#my_room\") { name } }"}}`)) + assert.Eventually(t, func() bool { + expectedMessage := []byte(`{"id":"1","type":"next","payload":{"data":{"room":{"name":"#my_room"}}}}`) + actualMessage := testClient.readMessageToClient() + assert.Equal(t, expectedMessage, actualMessage) + expectedMessage = []byte(`{"id":"1","type":"complete"}`) + actualMessage = testClient.readMessageToClient() + assert.Equal(t, expectedMessage, actualMessage) + return true + }, 2*time.Second, 2*time.Millisecond, "never satisfied on start non-subscription") + + testClient.writeMessageFromClient([]byte(`{"id":"2","type":"subscribe","payload":{"query":"subscription { messageAdded(roomName:\"#my_room\") { text } }"}}`)) + time.Sleep(15 * time.Millisecond) + testClient.writeMessageFromClient([]byte(`{"id":"3","type":"subscribe","payload":{"query":"mutation { post(text: \"hello\", username: \"me\", roomName: \"#my_room\") { text } }"}}`)) + assert.Eventually(t, func() bool { + expectedMessages := []string{ + `{"id":"3","type":"next","payload":{"data":{"post":{"text":"hello"}}}}`, + `{"id":"3","type":"complete"}`, + `{"id":"2","type":"next","payload":{"data":{"messageAdded":{"text":"hello"}}}}`, + } + actualMessage := testClient.readMessageToClient() + assert.Contains(t, expectedMessages, string(actualMessage)) + actualMessage = testClient.readMessageToClient() + assert.Contains(t, expectedMessages, string(actualMessage)) + actualMessage = testClient.readMessageToClient() + assert.Contains(t, expectedMessages, string(actualMessage)) + return true + }, 2*time.Second, 2*time.Millisecond, "never satisfied on start subscription") + + testClient.writeMessageFromClient([]byte(`{"id":"2","type":"complete"}`)) + assert.Eventually(t, func() bool { + expectedMessage := []byte(`{"id":"2","type":"complete"}`) + actualMessage := testClient.readMessageToClient() + assert.Equal(t, expectedMessage, actualMessage) + return true + }, 2*time.Second, 2*time.Millisecond, "never satisfied on stop subscription") + }) + + t.Run("should handle on before start error", func(t *testing.T) { + chatServer := httptest.NewServer(subscriptiontesting.ChatGraphQLEndpointHandler()) + defer chatServer.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + executorPoolV2 := setupExecutorPoolV2(t, ctx, chatServer.URL, &FailingOnBeforeStartHook{}) + serverConn, _ := net.Pipe() + testClient := NewTestClient(false) + + done := make(chan bool) + errChan := make(chan error) + go Handle( + done, + errChan, + serverConn, + executorPoolV2, + WithProtocol(ProtocolGraphQLTransportWS), + WithCustomClient(testClient), + WithCustomSubscriptionUpdateInterval(50*time.Millisecond), + WithCustomKeepAliveInterval(3600*time.Second), // keep_alive should not intervene with our tests, so make it high + ) + + require.Eventually(t, func() bool { + <-done + return true + }, 1*time.Second, 2*time.Millisecond) + + testClient.writeMessageFromClient([]byte(`{"type":"connection_init"}`)) + assert.Eventually(t, func() bool { + expectedMessage := []byte(`{"type":"connection_ack"}`) + actualMessage := testClient.readMessageToClient() + assert.Equal(t, expectedMessage, actualMessage) + return true + }, 1*time.Second, 2*time.Millisecond, "never satisfied on connection_init") + + testClient.writeMessageFromClient([]byte(`{"id":"1","type":"subscribe","payload":{"query":"{ room(name:\"#my_room\") { name } }"}}`)) + assert.Eventually(t, func() bool { + expectedMessage := []byte(`{"id":"1","type":"error","payload":[{"message":"on before start error"}]}`) + actualMessage := testClient.readMessageToClient() + assert.Equal(t, expectedMessage, actualMessage) + return true + }, 2*time.Second, 2*time.Millisecond, "never satisfied on before start error") + }) +} + +func TestWithProtocolFromRequestHeaders(t *testing.T) { + runTest := func(headerKey string, headerValue string, expectedProtocol Protocol) func(t *testing.T) { + return func(t *testing.T) { + request, err := http.NewRequest("", "", nil) + require.NoError(t, err) + request.Header.Set(headerKey, headerValue) + + options := &HandleOptions{} + optionFunc := WithProtocolFromRequestHeaders(request) + optionFunc(options) + + assert.Equal(t, expectedProtocol, options.Protocol) + } + } + + t.Run("should detect graphql-ws", runTest(HeaderSecWebSocketProtocol, "graphql-ws", ProtocolGraphQLWS)) + t.Run("should detect graphql-transport-ws", runTest(HeaderSecWebSocketProtocol, "graphql-transport-ws", ProtocolGraphQLTransportWS)) + t.Run("should fallback to default protocol", runTest(HeaderSecWebSocketProtocol, "something-else", DefaultProtocol)) + t.Run("should fallback to default protocol when header is missing", runTest("Different-Header-Key", "missing-header", DefaultProtocol)) + t.Run("should fallback to default protocol when request is nil", func(t *testing.T) { + options := &HandleOptions{} + optionFunc := WithProtocolFromRequestHeaders(nil) + optionFunc(options) + assert.Equal(t, DefaultProtocol, options.Protocol) + }) +} + +func setupExecutorPoolV2(t *testing.T, ctx context.Context, chatServerURL string, onBeforeStartHook graphql.WebsocketBeforeStartHook) *subscription.ExecutorV2Pool { + chatSchemaBytes, err := subscriptiontesting.LoadSchemaFromExamplesDirectoryWithinPkg() + require.NoError(t, err) + + chatSchema, err := graphql.NewSchemaFromReader(bytes.NewBuffer(chatSchemaBytes)) + require.NoError(t, err) + + engineConf := graphql.NewEngineV2Configuration(chatSchema) + engineConf.SetWebsocketBeforeStartHook(onBeforeStartHook) + engineConf.SetDataSources([]plan.DataSourceConfiguration{ + { + RootNodes: []plan.TypeField{ + {TypeName: "Query", FieldNames: []string{"room"}}, + {TypeName: "Mutation", FieldNames: []string{"post"}}, + {TypeName: "Subscription", FieldNames: []string{"messageAdded"}}, + }, + ChildNodes: []plan.TypeField{ + {TypeName: "Chatroom", FieldNames: []string{"name", "messages"}}, + {TypeName: "Message", FieldNames: []string{"text", "createdBy"}}, + }, + Factory: &graphql_datasource.Factory{ + HTTPClient: httpclient.DefaultNetHttpClient, + }, + Custom: graphql_datasource.ConfigJson(graphql_datasource.Configuration{ + Fetch: graphql_datasource.FetchConfiguration{ + URL: chatServerURL, + Method: http.MethodPost, + Header: nil, + }, + Subscription: graphql_datasource.SubscriptionConfiguration{ + URL: chatServerURL, + }, + }), + }, + }) + engineConf.SetFieldConfigurations([]plan.FieldConfiguration{ + { + TypeName: "Query", + FieldName: "room", + Arguments: []plan.ArgumentConfiguration{ + { + Name: "name", + SourceType: plan.FieldArgumentSource, + }, + }, + }, + { + TypeName: "Mutation", + FieldName: "post", + Arguments: []plan.ArgumentConfiguration{ + { + Name: "roomName", + SourceType: plan.FieldArgumentSource, + }, + { + Name: "username", + SourceType: plan.FieldArgumentSource, + }, + { + Name: "text", + SourceType: plan.FieldArgumentSource, + }, + }, + }, + { + TypeName: "Subscription", + FieldName: "messageAdded", + Arguments: []plan.ArgumentConfiguration{ + { + Name: "roomName", + SourceType: plan.FieldArgumentSource, + }, + }, + }, + }) + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, "http://localhost:8080", nil) + require.NoError(t, err) + + req.Header.Set("X-Other-Key", "x-other-value") + + initCtx := subscription.NewInitialHttpRequestContext(req) + + engine, err := graphql.NewExecutionEngineV2(initCtx, abstractlogger.NoopLogger, engineConf) + require.NoError(t, err) + + executorPool := subscription.NewExecutorV2Pool(engine, ctx) + return executorPool +} + +type FailingOnBeforeStartHook struct{} + +func (f *FailingOnBeforeStartHook) OnBeforeStart(reqCtx context.Context, operation *graphql.Request) error { + return errors.New("on before start error") +} diff --git a/v2/pkg/subscription/websocket/init.go b/v2/pkg/subscription/websocket/init.go new file mode 100644 index 000000000..224695479 --- /dev/null +++ b/v2/pkg/subscription/websocket/init.go @@ -0,0 +1,47 @@ +package websocket + +import ( + "context" + "encoding/json" +) + +// InitFunc is called when the server receives connection init message from the client. +// This can be used to check initial payload to see whether to accept the websocket connection. +type InitFunc func(ctx context.Context, initPayload InitPayload) (context.Context, error) + +// InitPayload is a structure that is parsed from the websocket init message payload. +type InitPayload json.RawMessage + +// GetString safely gets a string value from the payload. It returns an empty string if the +// payload is nil or the value isn't set. +func (p InitPayload) GetString(key string) string { + if p == nil { + return "" + } + + var payload map[string]interface{} + if err := json.Unmarshal(p, &payload); err != nil { + return "" + } + + if value, ok := payload[key]; ok { + res, _ := value.(string) + return res + } + + return "" +} + +// Authorization is a shorthand for getting the Authorization header from the +// payload. +func (p InitPayload) Authorization() string { + if value := p.GetString("Authorization"); value != "" { + return value + } + + if value := p.GetString("authorization"); value != "" { + return value + } + + return "" +} diff --git a/v2/pkg/subscription/websocket/protocol_graphql_transport_ws.go b/v2/pkg/subscription/websocket/protocol_graphql_transport_ws.go new file mode 100644 index 000000000..7ecc17b56 --- /dev/null +++ b/v2/pkg/subscription/websocket/protocol_graphql_transport_ws.go @@ -0,0 +1,524 @@ +package websocket + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "sync" + "time" + + "github.com/jensneuse/abstractlogger" + + "github.com/wundergraph/graphql-go-tools/pkg/graphql" + "github.com/wundergraph/graphql-go-tools/pkg/subscription" +) + +// GraphQLTransportWSMessageType is a type that defines graphql-transport-ws message type names. +type GraphQLTransportWSMessageType string + +const ( + GraphQLTransportWSMessageTypeConnectionInit GraphQLTransportWSMessageType = "connection_init" + GraphQLTransportWSMessageTypeConnectionAck GraphQLTransportWSMessageType = "connection_ack" + GraphQLTransportWSMessageTypePing GraphQLTransportWSMessageType = "ping" + GraphQLTransportWSMessageTypePong GraphQLTransportWSMessageType = "pong" + GraphQLTransportWSMessageTypeSubscribe GraphQLTransportWSMessageType = "subscribe" + GraphQLTransportWSMessageTypeNext GraphQLTransportWSMessageType = "next" + GraphQLTransportWSMessageTypeError GraphQLTransportWSMessageType = "error" + GraphQLTransportWSMessageTypeComplete GraphQLTransportWSMessageType = "complete" +) + +const ( + GraphQLTransportWSHeartbeatPayload = `{"type":"heartbeat"}` +) + +// GraphQLTransportWSMessage is a struct that can be (de)serialized to graphql-transport-ws message format. +type GraphQLTransportWSMessage struct { + Id string `json:"id,omitempty"` + Type GraphQLTransportWSMessageType `json:"type"` + Payload json.RawMessage `json:"payload,omitempty"` +} + +// GraphQLTransportWSMessageSubscribePayload is a struct that can be (de)serialized to graphql-transport-ws message payload format. +type GraphQLTransportWSMessageSubscribePayload struct { + OperationName string `json:"operationName,omitempty"` + Query string `json:"query"` + Variables json.RawMessage `json:"variables,omitempty"` + Extensions json.RawMessage `json:"extensions,omitempty"` +} + +// GraphQLTransportWSMessageReader can be used to read graphql-transport-ws messages. +type GraphQLTransportWSMessageReader struct { + logger abstractlogger.Logger +} + +// Read deserializes a byte slice to the GraphQLTransportWSMessage struct. +func (g *GraphQLTransportWSMessageReader) Read(data []byte) (*GraphQLTransportWSMessage, error) { + var message GraphQLTransportWSMessage + err := json.Unmarshal(data, &message) + if err != nil { + g.logger.Error("websocket.GraphQLTransportWSMessageReader.Read: on json unmarshal", + abstractlogger.Error(err), + abstractlogger.ByteString("data", data), + ) + + return nil, err + } + return &message, nil +} + +// DeserializeSubscribePayload deserialized the subscribe payload from a graphql-transport-ws message. +func (g *GraphQLTransportWSMessageReader) DeserializeSubscribePayload(message *GraphQLTransportWSMessage) (*GraphQLTransportWSMessageSubscribePayload, error) { + var deserializedPayload GraphQLTransportWSMessageSubscribePayload + err := json.Unmarshal(message.Payload, &deserializedPayload) + if err != nil { + g.logger.Error("websocket.GraphQLTransportWSMessageReader.DeserializeSubscribePayload: on subscribe payload deserialization", + abstractlogger.Error(err), + abstractlogger.ByteString("payload", message.Payload), + ) + return nil, err + } + + return &deserializedPayload, nil +} + +// GraphQLTransportWSMessageWriter can be used to write graphql-transport-ws messages to a transport client. +type GraphQLTransportWSMessageWriter struct { + logger abstractlogger.Logger + mu *sync.Mutex + Client subscription.TransportClient +} + +// WriteConnectionAck writes a message of type 'connection_ack' to the transport client. +func (g *GraphQLTransportWSMessageWriter) WriteConnectionAck() error { + message := &GraphQLTransportWSMessage{ + Type: GraphQLTransportWSMessageTypeConnectionAck, + } + return g.write(message) +} + +// WritePing writes a message of type 'ping' to the transport client. Payload is optional. +func (g *GraphQLTransportWSMessageWriter) WritePing(payload []byte) error { + message := &GraphQLTransportWSMessage{ + Type: GraphQLTransportWSMessageTypePing, + Payload: payload, + } + return g.write(message) +} + +// WritePong writes a message of type 'pong' to the transport client. Payload is optional. +func (g *GraphQLTransportWSMessageWriter) WritePong(payload []byte) error { + message := &GraphQLTransportWSMessage{ + Type: GraphQLTransportWSMessageTypePong, + Payload: payload, + } + return g.write(message) +} + +// WriteNext writes a message of type 'next' to the transport client including the execution result as payload. +func (g *GraphQLTransportWSMessageWriter) WriteNext(id string, executionResult []byte) error { + message := &GraphQLTransportWSMessage{ + Id: id, + Type: GraphQLTransportWSMessageTypeNext, + Payload: executionResult, + } + return g.write(message) +} + +// WriteError writes a message of type 'error' to the transport client including the graphql errors as payload. +func (g *GraphQLTransportWSMessageWriter) WriteError(id string, graphqlErrors graphql.RequestErrors) error { + payloadBytes, err := json.Marshal(graphqlErrors) + if err != nil { + return err + } + message := &GraphQLTransportWSMessage{ + Id: id, + Type: GraphQLTransportWSMessageTypeError, + Payload: payloadBytes, + } + return g.write(message) +} + +// WriteComplete writes a message of type 'complete' to the transport client. +func (g *GraphQLTransportWSMessageWriter) WriteComplete(id string) error { + message := &GraphQLTransportWSMessage{ + Id: id, + Type: GraphQLTransportWSMessageTypeComplete, + } + return g.write(message) +} + +func (g *GraphQLTransportWSMessageWriter) write(message *GraphQLTransportWSMessage) error { + jsonData, err := json.Marshal(message) + if err != nil { + g.logger.Error("websocket.GraphQLTransportWSMessageWriter.write: on json marshal", + abstractlogger.Error(err), + abstractlogger.String("id", message.Id), + abstractlogger.String("type", string(message.Type)), + abstractlogger.Any("payload", message.Payload), + ) + return err + } + g.mu.Lock() + defer g.mu.Unlock() + return g.Client.WriteBytesToClient(jsonData) +} + +// GraphQLTransportWSEventHandler can be used to handle subscription events and forward them to a GraphQLTransportWSMessageWriter. +type GraphQLTransportWSEventHandler struct { + logger abstractlogger.Logger + Writer GraphQLTransportWSMessageWriter + OnConnectionOpened func() +} + +// Emit is an implementation of subscription.EventHandler. It forwards some events to the HandleWriteEvent. +func (g *GraphQLTransportWSEventHandler) Emit(eventType subscription.EventType, id string, data []byte, err error) { + messageType := GraphQLTransportWSMessageType("") + switch eventType { + case subscription.EventTypeOnSubscriptionCompleted: + messageType = GraphQLTransportWSMessageTypeComplete + case subscription.EventTypeOnSubscriptionData: + messageType = GraphQLTransportWSMessageTypeNext + case subscription.EventTypeOnNonSubscriptionExecutionResult: + g.HandleWriteEvent(GraphQLTransportWSMessageTypeNext, id, data, err) + g.HandleWriteEvent(GraphQLTransportWSMessageTypeComplete, id, data, err) + return + case subscription.EventTypeOnError: + messageType = GraphQLTransportWSMessageTypeError + case subscription.EventTypeOnConnectionOpened: + if g.OnConnectionOpened != nil { + g.OnConnectionOpened() + } + return + case subscription.EventTypeOnDuplicatedSubscriberID: + err = g.Writer.Client.DisconnectWithReason( + NewCloseReason(4409, fmt.Sprintf("Subscriber for %s already exists", id)), + ) + + if err != nil { + g.logger.Error("websocket.GraphQLTransportWSEventHandler.Emit: on duplicate subscriber id handling", + abstractlogger.Error(err), + abstractlogger.String("id", id), + abstractlogger.String("type", string(messageType)), + abstractlogger.ByteString("payload", data), + ) + } + return + default: + return + } + g.HandleWriteEvent(messageType, id, data, err) +} + +// HandleWriteEvent forwards messages to the underlying writer. +func (g *GraphQLTransportWSEventHandler) HandleWriteEvent(messageType GraphQLTransportWSMessageType, id string, data []byte, providedErr error) { + var err error + switch messageType { + case GraphQLTransportWSMessageTypeComplete: + err = g.Writer.WriteComplete(id) + case GraphQLTransportWSMessageTypeNext: + err = g.Writer.WriteNext(id, data) + case GraphQLTransportWSMessageTypeError: + err = g.Writer.WriteError(id, graphql.RequestErrorsFromError(providedErr)) + case GraphQLTransportWSMessageTypeConnectionAck: + err = g.Writer.WriteConnectionAck() + case GraphQLTransportWSMessageTypePing: + err = g.Writer.WritePing(data) + case GraphQLTransportWSMessageTypePong: + err = g.Writer.WritePong(data) + default: + g.logger.Warn("websocket.GraphQLTransportWSEventHandler.HandleWriteEvent: on write event handling with unexpected message type", + abstractlogger.Error(err), + abstractlogger.String("id", id), + abstractlogger.String("type", string(messageType)), + abstractlogger.ByteString("payload", data), + abstractlogger.Error(providedErr), + ) + err = g.Writer.Client.DisconnectWithReason( + NewCloseReason( + 4400, + fmt.Sprintf("invalid type '%s'", string(messageType)), + ), + ) + if err != nil { + g.logger.Error("websocket.GraphQLTransportWSEventHandler.HandleWriteEvent: after disconnecting on write event handling with unexpected message type", + abstractlogger.Error(err), + abstractlogger.String("id", id), + abstractlogger.String("type", string(messageType)), + abstractlogger.ByteString("payload", data), + ) + } + return + } + if err != nil { + g.logger.Error("websocket.GraphQLTransportWSEventHandler.HandleWriteEvent: on write event handling", + abstractlogger.Error(err), + abstractlogger.String("id", id), + abstractlogger.String("type", string(messageType)), + abstractlogger.ByteString("payload", data), + abstractlogger.Error(providedErr), + ) + } +} + +// ProtocolGraphQLTransportWSHandlerOptions can be used to provide options to the graphql-transport-ws protocol handler. +type ProtocolGraphQLTransportWSHandlerOptions struct { + Logger abstractlogger.Logger + WebSocketInitFunc InitFunc + CustomKeepAliveInterval time.Duration + CustomInitTimeOutDuration time.Duration +} + +// ProtocolGraphQLTransportWSHandler is able to handle the graphql-transport-ws protocol. +type ProtocolGraphQLTransportWSHandler struct { + logger abstractlogger.Logger + reader GraphQLTransportWSMessageReader + eventHandler GraphQLTransportWSEventHandler + connectionInitialized bool + heartbeatInterval time.Duration + heartbeatStarted bool + initFunc InitFunc + connectionAcknowledged bool + connectionInitTimerStarted bool + connectionInitTimeOutCancel context.CancelFunc + connectionInitTimeOutDuration time.Duration +} + +// NewProtocolGraphQLTransportWSHandler creates a new ProtocolGraphQLTransportWSHandler with default options. +func NewProtocolGraphQLTransportWSHandler(client subscription.TransportClient) (*ProtocolGraphQLTransportWSHandler, error) { + return NewProtocolGraphQLTransportWSHandlerWithOptions(client, ProtocolGraphQLTransportWSHandlerOptions{}) +} + +// NewProtocolGraphQLTransportWSHandlerWithOptions creates a new ProtocolGraphQLTransportWSHandler. It requires an option struct. +func NewProtocolGraphQLTransportWSHandlerWithOptions(client subscription.TransportClient, opts ProtocolGraphQLTransportWSHandlerOptions) (*ProtocolGraphQLTransportWSHandler, error) { + protocolHandler := &ProtocolGraphQLTransportWSHandler{ + logger: abstractlogger.Noop{}, + reader: GraphQLTransportWSMessageReader{ + logger: abstractlogger.Noop{}, + }, + eventHandler: GraphQLTransportWSEventHandler{ + logger: abstractlogger.Noop{}, + Writer: GraphQLTransportWSMessageWriter{ + logger: abstractlogger.Noop{}, + Client: client, + mu: &sync.Mutex{}, + }, + }, + initFunc: opts.WebSocketInitFunc, + } + + if opts.Logger != nil { + protocolHandler.logger = opts.Logger + protocolHandler.reader.logger = opts.Logger + protocolHandler.eventHandler.logger = opts.Logger + protocolHandler.eventHandler.Writer.logger = opts.Logger + } + + if opts.CustomKeepAliveInterval != 0 { + protocolHandler.heartbeatInterval = opts.CustomKeepAliveInterval + } else { + parsedKeepAliveInterval, err := time.ParseDuration(subscription.DefaultKeepAliveInterval) + if err != nil { + return nil, err + } + protocolHandler.heartbeatInterval = parsedKeepAliveInterval + } + + if opts.CustomInitTimeOutDuration != 0 { + protocolHandler.connectionInitTimeOutDuration = opts.CustomInitTimeOutDuration + } else { + timeOutDuration, err := time.ParseDuration(DefaultConnectionInitTimeOut) + if err != nil { + return nil, err + } + protocolHandler.connectionInitTimeOutDuration = timeOutDuration + } + + // Pass event functions + protocolHandler.eventHandler.OnConnectionOpened = protocolHandler.startConnectionInitTimer + + return protocolHandler, nil +} + +// Handle will handle the actual graphql-transport-ws protocol messages. It's an implementation of subscription.Protocol. +func (p *ProtocolGraphQLTransportWSHandler) Handle(ctx context.Context, engine subscription.Engine, data []byte) error { + if !p.connectionAcknowledged && !p.connectionInitTimerStarted { + p.startConnectionInitTimer() + } + + message, err := p.reader.Read(data) + if err != nil { + var jsonSyntaxError *json.SyntaxError + if errors.As(err, &jsonSyntaxError) { + p.closeConnectionWithReason(NewCloseReason(4400, "JSON syntax error")) + return nil + } + p.logger.Error("websocket.ProtocolGraphQLTransportWSHandler.Handle: on message reading", + abstractlogger.Error(err), + abstractlogger.ByteString("payload", data), + ) + return err + } + switch message.Type { + case GraphQLTransportWSMessageTypeConnectionInit: + ctx, err = p.handleInit(ctx, message.Payload) + if err != nil { + p.logger.Error("websocket.ProtocolGraphQLTransportWSHandler.Handle: on handling init", + abstractlogger.Error(err), + ) + p.closeConnectionWithReason( + CompiledCloseReasonInternalServerError, + ) + } + p.startHeartbeat(ctx) + case GraphQLTransportWSMessageTypePing: + p.handlePing(message.Payload) + case GraphQLTransportWSMessageTypePong: + return nil // no need to act on pong currently (this may change in future for heartbeat checks) + case GraphQLTransportWSMessageTypeSubscribe: + return p.handleSubscribe(ctx, engine, message) + case GraphQLTransportWSMessageTypeComplete: + return p.handleComplete(engine, message.Id) + default: + p.closeConnectionWithReason( + NewCloseReason(4400, fmt.Sprintf("Invalid type '%s'", string(message.Type))), + ) + } + + return nil +} + +// EventHandler returns the underlying graphql-transport-ws event handler. It's an implementation of subscription.Protocol. +func (p *ProtocolGraphQLTransportWSHandler) EventHandler() subscription.EventHandler { + return &p.eventHandler +} + +func (p *ProtocolGraphQLTransportWSHandler) startConnectionInitTimer() { + if p.connectionInitTimerStarted { + return + } + + timeOutContext, timeOutContextCancel := context.WithCancel(context.Background()) + p.connectionInitTimeOutCancel = timeOutContextCancel + p.connectionInitTimerStarted = true + timeOutParams := subscription.TimeOutParams{ + Name: "connection init time out", + Logger: p.logger, + TimeOutContext: timeOutContext, + TimeOutAction: func() { + p.closeConnectionWithReason( + NewCloseReason(4408, "Connection initialisation timeout"), + ) + }, + TimeOutDuration: p.connectionInitTimeOutDuration, + } + go subscription.TimeOutChecker(timeOutParams) +} + +func (p *ProtocolGraphQLTransportWSHandler) stopConnectionInitTimer() bool { + if p.connectionInitTimeOutCancel == nil { + return false + } + + p.connectionInitTimeOutCancel() + p.connectionInitTimeOutCancel = nil + return true +} + +func (p *ProtocolGraphQLTransportWSHandler) startHeartbeat(ctx context.Context) { + if p.heartbeatStarted { + return + } + + p.heartbeatStarted = true + go p.heartbeat(ctx) +} + +func (p *ProtocolGraphQLTransportWSHandler) heartbeat(ctx context.Context) { + for { + select { + case <-ctx.Done(): + return + case <-time.After(p.heartbeatInterval): + p.eventHandler.HandleWriteEvent(GraphQLTransportWSMessageTypePong, "", []byte(GraphQLTransportWSHeartbeatPayload), nil) + } + } +} + +func (p *ProtocolGraphQLTransportWSHandler) handleInit(ctx context.Context, payload []byte) (context.Context, error) { + if p.connectionInitialized { + p.closeConnectionWithReason( + NewCloseReason(4429, "Too many initialisation requests"), + ) + return ctx, nil + } + + initCtx := ctx + if p.initFunc != nil && len(payload) > 0 { + // check initial payload to see whether to accept the websocket connection + var err error + if initCtx, err = p.initFunc(ctx, payload); err != nil { + return initCtx, err + } + } + + if p.stopConnectionInitTimer() { + p.eventHandler.HandleWriteEvent(GraphQLTransportWSMessageTypeConnectionAck, "", nil, nil) + } else { + p.closeConnectionWithReason(CompiledCloseReasonInternalServerError) + } + p.connectionInitialized = true + return initCtx, nil +} + +func (p *ProtocolGraphQLTransportWSHandler) handlePing(payload []byte) { + // Pong should return the same payload as ping. + // https://developer.mozilla.org/en-US/docs/Web/API/WebSockets_API/Writing_WebSocket_servers#pings_and_pongs_the_heartbeat_of_websockets + p.eventHandler.HandleWriteEvent(GraphQLTransportWSMessageTypePong, "", payload, nil) +} + +func (p *ProtocolGraphQLTransportWSHandler) handleSubscribe(ctx context.Context, engine subscription.Engine, message *GraphQLTransportWSMessage) error { + if !p.connectionInitialized { + p.closeConnectionWithReason( + NewCloseReason(4401, "Unauthorized"), + ) + return nil + } + + subscribePayload, err := p.reader.DeserializeSubscribePayload(message) + if err != nil { + return err + } + + enginePayload := graphql.Request{ + OperationName: subscribePayload.OperationName, + Query: subscribePayload.Query, + Variables: subscribePayload.Variables, + } + + enginePayloadBytes, err := json.Marshal(enginePayload) + if err != nil { + return err + } + + return engine.StartOperation(ctx, message.Id, enginePayloadBytes, &p.eventHandler) +} + +func (p *ProtocolGraphQLTransportWSHandler) handleComplete(engine subscription.Engine, id string) error { + return engine.StopSubscription(id, &p.eventHandler) +} + +func (p *ProtocolGraphQLTransportWSHandler) closeConnectionWithReason(reason interface{}) { + err := p.eventHandler.Writer.Client.DisconnectWithReason( + reason, + ) + if err != nil { + p.logger.Error("websocket.ProtocolGraphQLTransportWSHandler.closeConnectionWithReason: after trying to disconnect with reason", + abstractlogger.Error(err), + ) + } +} + +// Interface guards +var _ subscription.EventHandler = (*GraphQLTransportWSEventHandler)(nil) +var _ subscription.Protocol = (*ProtocolGraphQLTransportWSHandler)(nil) diff --git a/v2/pkg/subscription/websocket/protocol_graphql_transport_ws_test.go b/v2/pkg/subscription/websocket/protocol_graphql_transport_ws_test.go new file mode 100644 index 000000000..56440880d --- /dev/null +++ b/v2/pkg/subscription/websocket/protocol_graphql_transport_ws_test.go @@ -0,0 +1,569 @@ +package websocket + +import ( + "context" + "errors" + "runtime" + "sync" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/jensneuse/abstractlogger" + "github.com/stretchr/testify/assert" + + "github.com/wundergraph/graphql-go-tools/pkg/graphql" + "github.com/wundergraph/graphql-go-tools/pkg/subscription" +) + +func TestGraphQLTransportWSMessageReader_Read(t *testing.T) { + t.Run("should read a minimal message", func(t *testing.T) { + data := []byte(`{ "type": "connection_init" }`) + expectedMessage := &GraphQLTransportWSMessage{ + Type: "connection_init", + } + + reader := GraphQLTransportWSMessageReader{ + logger: abstractlogger.Noop{}, + } + message, err := reader.Read(data) + assert.NoError(t, err) + assert.Equal(t, expectedMessage, message) + }) + + t.Run("should message with json payload", func(t *testing.T) { + data := []byte(`{ "id": "1", "type": "connection_init", "payload": { "Authorization": "Bearer ey123" } }`) + expectedMessage := &GraphQLTransportWSMessage{ + Id: "1", + Type: "connection_init", + Payload: []byte(`{ "Authorization": "Bearer ey123" }`), + } + + reader := GraphQLTransportWSMessageReader{ + logger: abstractlogger.Noop{}, + } + message, err := reader.Read(data) + assert.NoError(t, err) + assert.Equal(t, expectedMessage, message) + }) + + t.Run("should read and deserialize subscribe message", func(t *testing.T) { + data := []byte(`{ + "id": "1", + "type": "subscribe", + "payload": { + "operationName": "MyQuery", + "query": "query MyQuery($name: String) { hello(name: $name) }", + "variables": { "name": "Udo" }, + "extensions": { "Authorization": "Bearer ey123" } + } +}`) + expectedMessage := &GraphQLTransportWSMessage{ + Id: "1", + Type: "subscribe", + Payload: []byte(`{ + "operationName": "MyQuery", + "query": "query MyQuery($name: String) { hello(name: $name) }", + "variables": { "name": "Udo" }, + "extensions": { "Authorization": "Bearer ey123" } + }`), + } + + reader := GraphQLTransportWSMessageReader{ + logger: abstractlogger.Noop{}, + } + message, err := reader.Read(data) + assert.NoError(t, err) + assert.Equal(t, expectedMessage, message) + + expectedPayload := &GraphQLTransportWSMessageSubscribePayload{ + OperationName: "MyQuery", + Query: "query MyQuery($name: String) { hello(name: $name) }", + Variables: []byte(`{ "name": "Udo" }`), + Extensions: []byte(`{ "Authorization": "Bearer ey123" }`), + } + actualPayload, err := reader.DeserializeSubscribePayload(message) + assert.NoError(t, err) + assert.Equal(t, expectedPayload, actualPayload) + }) +} + +func TestGraphQLTransportWSMessageWriter_WriteConnectionAck(t *testing.T) { + t.Run("should return error when error occurs on underlying call", func(t *testing.T) { + testClient := NewTestClient(true) + writer := GraphQLTransportWSMessageWriter{ + logger: abstractlogger.Noop{}, + Client: testClient, + mu: &sync.Mutex{}, + } + err := writer.WriteConnectionAck() + assert.Error(t, err) + }) + t.Run("should successfully write ack message to client", func(t *testing.T) { + testClient := NewTestClient(false) + writer := GraphQLTransportWSMessageWriter{ + logger: abstractlogger.Noop{}, + Client: testClient, + mu: &sync.Mutex{}, + } + expectedMessage := []byte(`{"type":"connection_ack"}`) + err := writer.WriteConnectionAck() + assert.NoError(t, err) + assert.Equal(t, expectedMessage, testClient.readMessageToClient()) + }) +} + +func TestGraphQLTransportWSMessageWriter_WritePing(t *testing.T) { + t.Run("should return error when error occurs on underlying call", func(t *testing.T) { + testClient := NewTestClient(true) + writer := GraphQLTransportWSMessageWriter{ + logger: abstractlogger.Noop{}, + Client: testClient, + mu: &sync.Mutex{}, + } + err := writer.WritePing(nil) + assert.Error(t, err) + }) + t.Run("should successfully write ping message to client", func(t *testing.T) { + testClient := NewTestClient(false) + writer := GraphQLTransportWSMessageWriter{ + logger: abstractlogger.Noop{}, + Client: testClient, + mu: &sync.Mutex{}, + } + expectedMessage := []byte(`{"type":"ping"}`) + err := writer.WritePing(nil) + assert.NoError(t, err) + assert.Equal(t, expectedMessage, testClient.readMessageToClient()) + }) + t.Run("should successfully write ping message with payload to client", func(t *testing.T) { + testClient := NewTestClient(false) + writer := GraphQLTransportWSMessageWriter{ + logger: abstractlogger.Noop{}, + Client: testClient, + mu: &sync.Mutex{}, + } + expectedMessage := []byte(`{"type":"ping","payload":{"connected_since":"10min"}}`) + err := writer.WritePing([]byte(`{"connected_since":"10min"}`)) + assert.NoError(t, err) + assert.Equal(t, expectedMessage, testClient.readMessageToClient()) + }) +} + +func TestGraphQLTransportWSMessageWriter_WritePong(t *testing.T) { + t.Run("should return error when error occurs on underlying call", func(t *testing.T) { + testClient := NewTestClient(true) + writer := GraphQLTransportWSMessageWriter{ + logger: abstractlogger.Noop{}, + Client: testClient, + mu: &sync.Mutex{}, + } + err := writer.WritePong(nil) + assert.Error(t, err) + }) + t.Run("should successfully write pong message to client", func(t *testing.T) { + testClient := NewTestClient(false) + writer := GraphQLTransportWSMessageWriter{ + logger: abstractlogger.Noop{}, + Client: testClient, + mu: &sync.Mutex{}, + } + expectedMessage := []byte(`{"type":"pong"}`) + err := writer.WritePong(nil) + assert.NoError(t, err) + assert.Equal(t, expectedMessage, testClient.readMessageToClient()) + }) + t.Run("should successfully write pong message with payload to client", func(t *testing.T) { + testClient := NewTestClient(false) + writer := GraphQLTransportWSMessageWriter{ + logger: abstractlogger.Noop{}, + Client: testClient, + mu: &sync.Mutex{}, + } + expectedMessage := []byte(`{"type":"pong","payload":{"connected_since":"10min"}}`) + err := writer.WritePong([]byte(`{"connected_since":"10min"}`)) + assert.NoError(t, err) + assert.Equal(t, expectedMessage, testClient.readMessageToClient()) + }) +} + +func TestGraphQLTransportWSMessageWriter_WriteNext(t *testing.T) { + t.Run("should return error when error occurs on underlying call", func(t *testing.T) { + testClient := NewTestClient(true) + writer := GraphQLTransportWSMessageWriter{ + logger: abstractlogger.Noop{}, + Client: testClient, + mu: &sync.Mutex{}, + } + err := writer.WriteNext("1", nil) + assert.Error(t, err) + }) + t.Run("should successfully write next message with payload to client", func(t *testing.T) { + testClient := NewTestClient(false) + writer := GraphQLTransportWSMessageWriter{ + logger: abstractlogger.Noop{}, + Client: testClient, + mu: &sync.Mutex{}, + } + expectedMessage := []byte(`{"id":"1","type":"next","payload":{"data":{"hello":"world"}}}`) + err := writer.WriteNext("1", []byte(`{"data":{"hello":"world"}}`)) + assert.NoError(t, err) + assert.Equal(t, expectedMessage, testClient.readMessageToClient()) + }) +} + +func TestGraphQLTransportWSMessageWriter_WriteError(t *testing.T) { + t.Run("should return error when error occurs on underlying call", func(t *testing.T) { + testClient := NewTestClient(true) + writer := GraphQLTransportWSMessageWriter{ + logger: abstractlogger.Noop{}, + Client: testClient, + mu: &sync.Mutex{}, + } + err := writer.WriteError("1", nil) + assert.Error(t, err) + }) + t.Run("should successfully write error message with payload to client", func(t *testing.T) { + testClient := NewTestClient(false) + writer := GraphQLTransportWSMessageWriter{ + logger: abstractlogger.Noop{}, + Client: testClient, + mu: &sync.Mutex{}, + } + expectedMessage := []byte(`{"id":"1","type":"error","payload":[{"message":"request error"}]}`) + requestErrors := graphql.RequestErrorsFromError(errors.New("request error")) + err := writer.WriteError("1", requestErrors) + assert.NoError(t, err) + assert.Equal(t, expectedMessage, testClient.readMessageToClient()) + }) +} + +func TestGraphQLTransportWSMessageWriter_WriteComplete(t *testing.T) { + t.Run("should return error when error occurs on underlying call", func(t *testing.T) { + testClient := NewTestClient(true) + writer := GraphQLTransportWSMessageWriter{ + logger: abstractlogger.Noop{}, + Client: testClient, + mu: &sync.Mutex{}, + } + err := writer.WriteComplete("1") + assert.Error(t, err) + }) + t.Run("should successfully write complete message to client", func(t *testing.T) { + testClient := NewTestClient(false) + writer := GraphQLTransportWSMessageWriter{ + logger: abstractlogger.Noop{}, + Client: testClient, + mu: &sync.Mutex{}, + } + expectedMessage := []byte(`{"id":"1","type":"complete"}`) + err := writer.WriteComplete("1") + assert.NoError(t, err) + assert.Equal(t, expectedMessage, testClient.readMessageToClient()) + }) +} + +func TestGraphQLTransportWSEventHandler_Emit(t *testing.T) { + t.Run("should write on completed", func(t *testing.T) { + testClient := NewTestClient(false) + eventHandler := NewTestGraphQLTransportWSEventHandler(testClient) + eventHandler.Emit(subscription.EventTypeOnSubscriptionCompleted, "1", nil, nil) + expectedMessage := []byte(`{"id":"1","type":"complete"}`) + assert.Equal(t, expectedMessage, testClient.readMessageToClient()) + }) + t.Run("should write on data", func(t *testing.T) { + testClient := NewTestClient(false) + eventHandler := NewTestGraphQLTransportWSEventHandler(testClient) + eventHandler.Emit(subscription.EventTypeOnSubscriptionData, "1", []byte(`{ "data": { "hello": "world" } }`), nil) + expectedMessage := []byte(`{"id":"1","type":"next","payload":{"data":{"hello":"world"}}}`) + assert.Equal(t, expectedMessage, testClient.readMessageToClient()) + }) + t.Run("should write on non-subscription execution result", func(t *testing.T) { + testClient := NewTestClient(false) + eventHandler := NewTestGraphQLTransportWSEventHandler(testClient) + go func() { + eventHandler.Emit(subscription.EventTypeOnNonSubscriptionExecutionResult, "1", []byte(`{ "data": { "hello": "world" } }`), nil) + }() + + assert.Eventually(t, func() bool { + expectedDataMessage := []byte(`{"id":"1","type":"next","payload":{"data":{"hello":"world"}}}`) + actualDataMessage := testClient.readMessageToClient() + assert.Equal(t, expectedDataMessage, actualDataMessage) + expectedCompleteMessage := []byte(`{"id":"1","type":"complete"}`) + actualCompleteMessage := testClient.readMessageToClient() + assert.Equal(t, expectedCompleteMessage, actualCompleteMessage) + return true + }, 1*time.Second, 2*time.Millisecond) + }) + t.Run("should write on error", func(t *testing.T) { + testClient := NewTestClient(false) + eventHandler := NewTestGraphQLTransportWSEventHandler(testClient) + eventHandler.Emit(subscription.EventTypeOnError, "1", nil, errors.New("error occurred")) + expectedMessage := []byte(`{"id":"1","type":"error","payload":[{"message":"error occurred"}]}`) + assert.Equal(t, expectedMessage, testClient.readMessageToClient()) + }) + t.Run("should execute the OnConnectionOpened event function", func(t *testing.T) { + counter := 0 + testClient := NewTestClient(false) + eventHandler := NewTestGraphQLTransportWSEventHandler(testClient) + eventHandler.OnConnectionOpened = func() { + counter++ + } + eventHandler.Emit(subscription.EventTypeOnConnectionOpened, "", nil, nil) + assert.Equal(t, counter, 1) + }) + t.Run("should disconnect on duplicated subscriber id", func(t *testing.T) { + testClient := NewTestClient(false) + eventHandler := NewTestGraphQLTransportWSEventHandler(testClient) + eventHandler.Emit(subscription.EventTypeOnDuplicatedSubscriberID, "1", nil, errors.New("subscriber already exists")) + assert.False(t, testClient.IsConnected()) + }) +} + +func TestGraphQLTransportWSWriteEventHandler_HandleWriteEvent(t *testing.T) { + t.Run("should write connection_ack", func(t *testing.T) { + testClient := NewTestClient(false) + writeEventHandler := NewTestGraphQLTransportWSEventHandler(testClient) + writeEventHandler.HandleWriteEvent(GraphQLTransportWSMessageTypeConnectionAck, "", nil, nil) + expectedMessage := []byte(`{"type":"connection_ack"}`) + assert.Equal(t, expectedMessage, testClient.readMessageToClient()) + }) + t.Run("should write ping", func(t *testing.T) { + testClient := NewTestClient(false) + writeEventHandler := NewTestGraphQLTransportWSEventHandler(testClient) + writeEventHandler.HandleWriteEvent(GraphQLTransportWSMessageTypePing, "", nil, nil) + expectedMessage := []byte(`{"type":"ping"}`) + assert.Equal(t, expectedMessage, testClient.readMessageToClient()) + }) + t.Run("should write pong", func(t *testing.T) { + testClient := NewTestClient(false) + writeEventHandler := NewTestGraphQLTransportWSEventHandler(testClient) + writeEventHandler.HandleWriteEvent(GraphQLTransportWSMessageTypePong, "", nil, nil) + expectedMessage := []byte(`{"type":"pong"}`) + assert.Equal(t, expectedMessage, testClient.readMessageToClient()) + }) + t.Run("should close connection on invalid type", func(t *testing.T) { + testClient := NewTestClient(false) + writeEventHandler := NewTestGraphQLTransportWSEventHandler(testClient) + writeEventHandler.HandleWriteEvent(GraphQLTransportWSMessageType("invalid"), "", nil, nil) + assert.False(t, writeEventHandler.Writer.Client.IsConnected()) + }) +} + +func TestProtocolGraphQLTransportWSHandler_Handle(t *testing.T) { + t.Run("should close connection when an unexpected message type is used", func(t *testing.T) { + testClient := NewTestClient(false) + protocol := NewTestProtocolGraphQLTransportWSHandler(testClient) + + ctrl := gomock.NewController(t) + mockEngine := NewMockEngine(ctrl) + + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + err := protocol.Handle(ctx, mockEngine, []byte(`{"type":"something"}`)) + assert.NoError(t, err) + assert.False(t, testClient.IsConnected()) + }) + + t.Run("for connection_init", func(t *testing.T) { + t.Run("should time out if no connection_init message is sent", func(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("this test fails on Windows due to different timings than unix, consider fixing it at some point") + } + testClient := NewTestClient(false) + protocol := NewTestProtocolGraphQLTransportWSHandler(testClient) + protocol.connectionInitTimeOutDuration = 2 * time.Millisecond + protocol.eventHandler.OnConnectionOpened = protocol.startConnectionInitTimer + + protocol.eventHandler.Emit(subscription.EventTypeOnConnectionOpened, "", nil, nil) + time.Sleep(10 * time.Millisecond) + assert.True(t, protocol.connectionInitTimerStarted) + assert.False(t, protocol.eventHandler.Writer.Client.IsConnected()) + }) + + t.Run("should close connection after multiple connection_init messages", func(t *testing.T) { + testClient := NewTestClient(false) + protocol := NewTestProtocolGraphQLTransportWSHandler(testClient) + protocol.connectionInitTimeOutDuration = 50 * time.Millisecond + protocol.eventHandler.OnConnectionOpened = protocol.startConnectionInitTimer + + ctrl := gomock.NewController(t) + mockEngine := NewMockEngine(ctrl) + + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + protocol.eventHandler.Emit(subscription.EventTypeOnConnectionOpened, "", nil, nil) + assert.Eventually(t, func() bool { + expectedAckMessage := []byte(`{"type":"connection_ack"}`) + time.Sleep(5 * time.Millisecond) + err := protocol.Handle(ctx, mockEngine, []byte(`{"type":"connection_init"}`)) + assert.NoError(t, err) + assert.Equal(t, expectedAckMessage, testClient.readMessageToClient()) + time.Sleep(1 * time.Millisecond) + err = protocol.Handle(ctx, mockEngine, []byte(`{"type":"connection_init"}`)) + assert.NoError(t, err) + assert.False(t, protocol.eventHandler.Writer.Client.IsConnected()) + return true + }, 1*time.Second, 2*time.Millisecond) + + }) + + t.Run("should not time out if connection_init message is sent before time out", func(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("this test fails on Windows due to different timings than unix, consider fixing it at some point") + } + testClient := NewTestClient(false) + protocol := NewTestProtocolGraphQLTransportWSHandler(testClient) + protocol.heartbeatInterval = 4 * time.Millisecond + protocol.connectionInitTimeOutDuration = 25 * time.Millisecond + protocol.eventHandler.OnConnectionOpened = protocol.startConnectionInitTimer + + ctrl := gomock.NewController(t) + mockEngine := NewMockEngine(ctrl) + + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + protocol.eventHandler.Emit(subscription.EventTypeOnConnectionOpened, "", nil, nil) + assert.Eventually(t, func() bool { + expectedAckMessage := []byte(`{"type":"connection_ack"}`) + expectedHeartbeatMessage := []byte(`{"type":"pong","payload":{"type":"heartbeat"}}`) + time.Sleep(1 * time.Millisecond) + err := protocol.Handle(ctx, mockEngine, []byte(`{"type":"connection_init"}`)) + assert.NoError(t, err) + assert.Equal(t, expectedAckMessage, testClient.readMessageToClient()) + time.Sleep(6 * time.Millisecond) + assert.Equal(t, expectedHeartbeatMessage, testClient.readMessageToClient()) + time.Sleep(50 * time.Millisecond) + assert.True(t, protocol.eventHandler.Writer.Client.IsConnected()) + assert.True(t, protocol.connectionInitTimerStarted) + assert.Nil(t, protocol.connectionInitTimeOutCancel) + return true + }, 1*time.Second, 2*time.Millisecond) + + }) + }) + + t.Run("should return pong on ping", func(t *testing.T) { + testClient := NewTestClient(false) + protocol := NewTestProtocolGraphQLTransportWSHandler(testClient) + + ctrl := gomock.NewController(t) + mockEngine := NewMockEngine(ctrl) + + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + assert.Eventually(t, func() bool { + inputMessage := []byte(`{"type":"ping","payload":{"status":"ok"}}`) + expectedMessage := []byte(`{"type":"pong","payload":{"status":"ok"}}`) + err := protocol.Handle(ctx, mockEngine, inputMessage) + assert.NoError(t, err) + assert.Equal(t, expectedMessage, testClient.readMessageToClient()) + return true + }, 1*time.Second, 2*time.Millisecond) + }) + + t.Run("should handle subscribe", func(t *testing.T) { + testClient := NewTestClient(false) + protocol := NewTestProtocolGraphQLTransportWSHandler(testClient) + + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + operation := []byte(`{"operationName":"Hello","query":"query Hello { hello }"}`) + ctrl := gomock.NewController(t) + mockEngine := NewMockEngine(ctrl) + mockEngine.EXPECT().StartOperation(gomock.Eq(ctx), gomock.Eq("1"), gomock.Eq(operation), gomock.Eq(&protocol.eventHandler)) + + assert.Eventually(t, func() bool { + inputMessage := []byte(`{"id":"1","type":"subscribe","payload":{"operationName":"Hello","query":"query Hello { hello }"}}`) + err := protocol.Handle(ctx, mockEngine, inputMessage) + assert.NoError(t, err) + return true + }, 1*time.Second, 2*time.Millisecond) + }) + + t.Run("should handle complete", func(t *testing.T) { + testClient := NewTestClient(false) + protocol := NewTestProtocolGraphQLTransportWSHandler(testClient) + + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + ctrl := gomock.NewController(t) + mockEngine := NewMockEngine(ctrl) + mockEngine.EXPECT().StopSubscription(gomock.Eq("1"), gomock.Eq(&protocol.eventHandler)) + + assert.Eventually(t, func() bool { + inputMessage := []byte(`{"id":"1","type":"complete"}`) + err := protocol.Handle(ctx, mockEngine, inputMessage) + assert.NoError(t, err) + return true + }, 1*time.Second, 2*time.Millisecond) + }) + + t.Run("should allow pong messages from client", func(t *testing.T) { + testClient := NewTestClient(false) + protocol := NewTestProtocolGraphQLTransportWSHandler(testClient) + + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + ctrl := gomock.NewController(t) + mockEngine := NewMockEngine(ctrl) + + assert.Eventually(t, func() bool { + inputMessage := []byte(`{"type":"pong"}`) + err := protocol.Handle(ctx, mockEngine, inputMessage) + assert.NoError(t, err) + assert.True(t, testClient.IsConnected()) + return true + }, 1*time.Second, 2*time.Millisecond) + }) + + t.Run("should not panic on broken input", func(t *testing.T) { + testClient := NewTestClient(false) + protocol := NewTestProtocolGraphQLTransportWSHandler(testClient) + + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + ctrl := gomock.NewController(t) + mockEngine := NewMockEngine(ctrl) + mockEngine.EXPECT().StopSubscription(gomock.Eq("1"), gomock.Eq(&protocol.eventHandler)) + + assert.Eventually(t, func() bool { + inputMessage := []byte(`{"type":"connection_init","payload":{something}}`) + err := protocol.Handle(ctx, mockEngine, inputMessage) + assert.NoError(t, err) + assert.False(t, testClient.IsConnected()) + return true + }, 1*time.Second, 2*time.Millisecond) + }) +} + +func NewTestGraphQLTransportWSEventHandler(testClient subscription.TransportClient) GraphQLTransportWSEventHandler { + return GraphQLTransportWSEventHandler{ + logger: abstractlogger.Noop{}, + Writer: GraphQLTransportWSMessageWriter{ + logger: abstractlogger.Noop{}, + mu: &sync.Mutex{}, + Client: testClient, + }, + } +} + +func NewTestProtocolGraphQLTransportWSHandler(testClient subscription.TransportClient) *ProtocolGraphQLTransportWSHandler { + return &ProtocolGraphQLTransportWSHandler{ + logger: abstractlogger.Noop{}, + reader: GraphQLTransportWSMessageReader{ + logger: abstractlogger.Noop{}, + }, + eventHandler: NewTestGraphQLTransportWSEventHandler(testClient), + heartbeatInterval: 30, + connectionInitTimeOutDuration: 10 * time.Second, + } +} diff --git a/v2/pkg/subscription/websocket/protocol_graphql_ws.go b/v2/pkg/subscription/websocket/protocol_graphql_ws.go new file mode 100644 index 000000000..8d16d17f7 --- /dev/null +++ b/v2/pkg/subscription/websocket/protocol_graphql_ws.go @@ -0,0 +1,359 @@ +package websocket + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "sync" + "time" + + "github.com/jensneuse/abstractlogger" + + "github.com/wundergraph/graphql-go-tools/pkg/graphql" + "github.com/wundergraph/graphql-go-tools/pkg/subscription" +) + +// GraphQLWSMessageType is a type that defines graphql-ws message type names. +type GraphQLWSMessageType string + +const ( + GraphQLWSMessageTypeConnectionInit GraphQLWSMessageType = "connection_init" + GraphQLWSMessageTypeConnectionAck GraphQLWSMessageType = "connection_ack" + GraphQLWSMessageTypeConnectionError GraphQLWSMessageType = "connection_error" + GraphQLWSMessageTypeConnectionTerminate GraphQLWSMessageType = "connection_terminate" + GraphQLWSMessageTypeConnectionKeepAlive GraphQLWSMessageType = "ka" + GraphQLWSMessageTypeStart GraphQLWSMessageType = "start" + GraphQLWSMessageTypeStop GraphQLWSMessageType = "stop" + GraphQLWSMessageTypeData GraphQLWSMessageType = "data" + GraphQLWSMessageTypeError GraphQLWSMessageType = "error" + GraphQLWSMessageTypeComplete GraphQLWSMessageType = "complete" +) + +var ErrGraphQLWSUnexpectedMessageType = errors.New("unexpected message type") + +// GraphQLWSMessage is a struct that can be (de)serialized to graphql-ws message format. +type GraphQLWSMessage struct { + Id string `json:"id,omitempty"` + Type GraphQLWSMessageType `json:"type"` + Payload json.RawMessage `json:"payload,omitempty"` +} + +// GraphQLWSMessageReader can be used to read graphql-ws messages. +type GraphQLWSMessageReader struct { + logger abstractlogger.Logger +} + +// Read deserializes a byte slice to the GraphQLWSMessage struct. +func (g *GraphQLWSMessageReader) Read(data []byte) (*GraphQLWSMessage, error) { + var message GraphQLWSMessage + err := json.Unmarshal(data, &message) + if err != nil { + g.logger.Error("websocket.GraphQLWSMessageReader.Read: on json unmarshal", + abstractlogger.Error(err), + abstractlogger.ByteString("data", data), + ) + + return nil, err + } + return &message, nil +} + +// GraphQLWSMessageWriter can be used to write graphql-ws messages to a transport client. +type GraphQLWSMessageWriter struct { + logger abstractlogger.Logger + mu *sync.Mutex + Client subscription.TransportClient +} + +// WriteData writes a message of type 'data' to the transport client. +func (g *GraphQLWSMessageWriter) WriteData(id string, responseData []byte) error { + message := &GraphQLWSMessage{ + Id: id, + Type: GraphQLWSMessageTypeData, + Payload: responseData, + } + return g.write(message) +} + +// WriteComplete writes a message of type 'complete' to the transport client. +func (g *GraphQLWSMessageWriter) WriteComplete(id string) error { + message := &GraphQLWSMessage{ + Id: id, + Type: GraphQLWSMessageTypeComplete, + Payload: nil, + } + return g.write(message) +} + +// WriteKeepAlive writes a message of type 'ka' to the transport client. +func (g *GraphQLWSMessageWriter) WriteKeepAlive() error { + message := &GraphQLWSMessage{ + Type: GraphQLWSMessageTypeConnectionKeepAlive, + Payload: nil, + } + return g.write(message) +} + +// WriteTerminate writes a message of type 'connection_terminate' to the transport client. +func (g *GraphQLWSMessageWriter) WriteTerminate(reason string) error { + payloadBytes, err := json.Marshal(reason) + if err != nil { + return err + } + message := &GraphQLWSMessage{ + Type: GraphQLWSMessageTypeConnectionTerminate, + Payload: payloadBytes, + } + return g.write(message) +} + +// WriteConnectionError writes a message of type 'connection_error' to the transport client. +func (g *GraphQLWSMessageWriter) WriteConnectionError(reason string) error { + payloadBytes, err := json.Marshal(reason) + if err != nil { + return err + } + message := &GraphQLWSMessage{ + Type: GraphQLWSMessageTypeConnectionError, + Payload: payloadBytes, + } + return g.write(message) +} + +// WriteError writes a message of type 'error' to the transport client. +func (g *GraphQLWSMessageWriter) WriteError(id string, errors graphql.RequestErrors) error { + payloadBytes, err := json.Marshal(errors) + if err != nil { + return err + } + message := &GraphQLWSMessage{ + Id: id, + Type: GraphQLWSMessageTypeError, + Payload: payloadBytes, + } + return g.write(message) +} + +// WriteAck writes a message of type 'connection_ack' to the transport client. +func (g *GraphQLWSMessageWriter) WriteAck() error { + message := &GraphQLWSMessage{ + Type: GraphQLWSMessageTypeConnectionAck, + } + return g.write(message) +} + +func (g *GraphQLWSMessageWriter) write(message *GraphQLWSMessage) error { + jsonData, err := json.Marshal(message) + if err != nil { + g.logger.Error("websocket.GraphQLWSMessageWriter.write: on json marshal", + abstractlogger.Error(err), + abstractlogger.String("id", message.Id), + abstractlogger.String("type", string(message.Type)), + abstractlogger.ByteString("payload", message.Payload), + ) + return err + } + g.mu.Lock() + defer g.mu.Unlock() + return g.Client.WriteBytesToClient(jsonData) +} + +// GraphQLWSWriteEventHandler can be used to handle subscription events and forward them to a GraphQLWSMessageWriter. +type GraphQLWSWriteEventHandler struct { + logger abstractlogger.Logger + Writer GraphQLWSMessageWriter +} + +// Emit is an implementation of subscription.EventHandler. It forwards events to the HandleWriteEvent. +func (g *GraphQLWSWriteEventHandler) Emit(eventType subscription.EventType, id string, data []byte, err error) { + messageType := GraphQLWSMessageType("") + switch eventType { + case subscription.EventTypeOnSubscriptionCompleted: + messageType = GraphQLWSMessageTypeComplete + case subscription.EventTypeOnSubscriptionData: + messageType = GraphQLWSMessageTypeData + case subscription.EventTypeOnNonSubscriptionExecutionResult: + g.HandleWriteEvent(GraphQLWSMessageTypeData, id, data, err) + g.HandleWriteEvent(GraphQLWSMessageTypeComplete, id, data, err) + return + case subscription.EventTypeOnError: + messageType = GraphQLWSMessageTypeError + case subscription.EventTypeOnDuplicatedSubscriberID: + messageType = GraphQLWSMessageTypeError + case subscription.EventTypeOnConnectionError: + messageType = GraphQLWSMessageTypeConnectionError + default: + return + } + + g.HandleWriteEvent(messageType, id, data, err) +} + +// HandleWriteEvent forwards messages to the underlying writer. +func (g *GraphQLWSWriteEventHandler) HandleWriteEvent(messageType GraphQLWSMessageType, id string, data []byte, providedErr error) { + var err error + switch messageType { + case GraphQLWSMessageTypeComplete: + err = g.Writer.WriteComplete(id) + case GraphQLWSMessageTypeData: + err = g.Writer.WriteData(id, data) + case GraphQLWSMessageTypeError: + err = g.Writer.WriteError(id, graphql.RequestErrorsFromError(providedErr)) + case GraphQLWSMessageTypeConnectionError: + err = g.Writer.WriteConnectionError(providedErr.Error()) + case GraphQLWSMessageTypeConnectionKeepAlive: + err = g.Writer.WriteKeepAlive() + case GraphQLWSMessageTypeConnectionAck: + err = g.Writer.WriteAck() + default: + g.logger.Warn("websocket.GraphQLWSWriteEventHandler.HandleWriteEvent: on write event handling with unexpected message type", + abstractlogger.Error(err), + abstractlogger.String("id", id), + abstractlogger.String("type", string(messageType)), + abstractlogger.ByteString("payload", data), + abstractlogger.Error(providedErr), + ) + return + } + if err != nil { + g.logger.Error("websocket.GraphQLWSWriteEventHandler.HandleWriteEvent: on write event handling", + abstractlogger.Error(err), + abstractlogger.String("id", id), + abstractlogger.String("type", string(messageType)), + abstractlogger.ByteString("payload", data), + abstractlogger.Error(providedErr), + ) + } +} + +// ProtocolGraphQLWSHandlerOptions can be used to provide options to the graphql-ws protocol handler. +type ProtocolGraphQLWSHandlerOptions struct { + Logger abstractlogger.Logger + WebSocketInitFunc InitFunc + CustomKeepAliveInterval time.Duration +} + +// ProtocolGraphQLWSHandler is able to handle the graphql-ws protocol. +type ProtocolGraphQLWSHandler struct { + logger abstractlogger.Logger + reader GraphQLWSMessageReader + writeEventHandler GraphQLWSWriteEventHandler + keepAliveInterval time.Duration + initFunc InitFunc +} + +// NewProtocolGraphQLWSHandler creates a new ProtocolGraphQLWSHandler with default options. +func NewProtocolGraphQLWSHandler(client subscription.TransportClient) (*ProtocolGraphQLWSHandler, error) { + return NewProtocolGraphQLWSHandlerWithOptions(client, ProtocolGraphQLWSHandlerOptions{}) +} + +// NewProtocolGraphQLWSHandlerWithOptions creates a new ProtocolGraphQLWSHandler. It requires an option struct. +func NewProtocolGraphQLWSHandlerWithOptions(client subscription.TransportClient, opts ProtocolGraphQLWSHandlerOptions) (*ProtocolGraphQLWSHandler, error) { + protocolHandler := &ProtocolGraphQLWSHandler{ + logger: abstractlogger.Noop{}, + reader: GraphQLWSMessageReader{ + logger: abstractlogger.Noop{}, + }, + writeEventHandler: GraphQLWSWriteEventHandler{ + logger: abstractlogger.Noop{}, + Writer: GraphQLWSMessageWriter{ + logger: abstractlogger.Noop{}, + Client: client, + mu: &sync.Mutex{}, + }, + }, + initFunc: opts.WebSocketInitFunc, + } + + if opts.Logger != nil { + protocolHandler.logger = opts.Logger + protocolHandler.reader.logger = opts.Logger + protocolHandler.writeEventHandler.logger = opts.Logger + protocolHandler.writeEventHandler.Writer.logger = opts.Logger + } + + if opts.CustomKeepAliveInterval != 0 { + protocolHandler.keepAliveInterval = opts.CustomKeepAliveInterval + } else { + parsedKeepAliveInterval, err := time.ParseDuration(subscription.DefaultKeepAliveInterval) + if err != nil { + return nil, err + } + protocolHandler.keepAliveInterval = parsedKeepAliveInterval + } + + return protocolHandler, nil +} + +// Handle will handle the actual graphql-ws protocol messages. It's an implementation of subscription.Protocol. +func (p *ProtocolGraphQLWSHandler) Handle(ctx context.Context, engine subscription.Engine, data []byte) error { + message, err := p.reader.Read(data) + if err != nil { + var jsonSyntaxError *json.SyntaxError + if errors.As(err, &jsonSyntaxError) { + p.writeEventHandler.HandleWriteEvent(GraphQLWSMessageTypeError, "", nil, errors.New("json syntax error")) + return nil + } + p.logger.Error("websocket.ProtocolGraphQLWSHandler.Handle: on message reading", + abstractlogger.Error(err), + abstractlogger.ByteString("payload", data), + ) + return err + } + + switch message.Type { + case GraphQLWSMessageTypeConnectionInit: + ctx, err = p.handleInit(ctx, message.Payload) + if err != nil { + p.writeEventHandler.HandleWriteEvent(GraphQLWSMessageTypeConnectionError, "", nil, errors.New("failed to accept the websocket connection")) + return engine.TerminateAllSubscriptions(&p.writeEventHandler) + } + + go p.handleKeepAlive(ctx) + case GraphQLWSMessageTypeStart: + return engine.StartOperation(ctx, message.Id, message.Payload, &p.writeEventHandler) + case GraphQLWSMessageTypeStop: + return engine.StopSubscription(message.Id, &p.writeEventHandler) + case GraphQLWSMessageTypeConnectionTerminate: + return engine.TerminateAllSubscriptions(&p.writeEventHandler) + default: + p.writeEventHandler.HandleWriteEvent(GraphQLWSMessageTypeConnectionError, message.Id, nil, fmt.Errorf("%s: %s", ErrGraphQLWSUnexpectedMessageType.Error(), message.Type)) + } + + return nil +} + +// EventHandler returns the underlying graphql-ws event handler. It's an implementation of subscription.Protocol. +func (p *ProtocolGraphQLWSHandler) EventHandler() subscription.EventHandler { + return &p.writeEventHandler +} + +func (p *ProtocolGraphQLWSHandler) handleInit(ctx context.Context, payload []byte) (context.Context, error) { + initCtx := ctx + if p.initFunc != nil && len(payload) > 0 { + // check initial payload to see whether to accept the websocket connection + var err error + if initCtx, err = p.initFunc(ctx, payload); err != nil { + return initCtx, err + } + } + + p.writeEventHandler.HandleWriteEvent(GraphQLWSMessageTypeConnectionAck, "", nil, nil) + return initCtx, nil +} + +func (p *ProtocolGraphQLWSHandler) handleKeepAlive(ctx context.Context) { + for { + select { + case <-ctx.Done(): + return + case <-time.After(p.keepAliveInterval): + p.writeEventHandler.HandleWriteEvent(GraphQLWSMessageTypeConnectionKeepAlive, "", nil, nil) + } + } +} + +// Interface guards +var _ subscription.EventHandler = (*GraphQLWSWriteEventHandler)(nil) +var _ subscription.Protocol = (*ProtocolGraphQLWSHandler)(nil) diff --git a/v2/pkg/subscription/websocket/protocol_graphql_ws_test.go b/v2/pkg/subscription/websocket/protocol_graphql_ws_test.go new file mode 100644 index 000000000..4c1420a6a --- /dev/null +++ b/v2/pkg/subscription/websocket/protocol_graphql_ws_test.go @@ -0,0 +1,411 @@ +package websocket + +import ( + "context" + "encoding/json" + "errors" + "sync" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/jensneuse/abstractlogger" + "github.com/stretchr/testify/assert" + + "github.com/wundergraph/graphql-go-tools/pkg/graphql" + "github.com/wundergraph/graphql-go-tools/pkg/subscription" +) + +func TestGraphQLWSMessageReader_Read(t *testing.T) { + data := []byte(`{ "id": "1", "type": "connection_init", "payload": { "headers": { "key": "value" } } }`) + expectedMessage := &GraphQLWSMessage{ + Id: "1", + Type: "connection_init", + Payload: json.RawMessage(`{ "headers": { "key": "value" } }`), + } + + reader := GraphQLWSMessageReader{ + logger: abstractlogger.Noop{}, + } + message, err := reader.Read(data) + assert.NoError(t, err) + assert.Equal(t, expectedMessage, message) +} + +func TestGraphQLWSMessageWriter_WriteData(t *testing.T) { + t.Run("should return error when error occurs on underlying call", func(t *testing.T) { + testClient := NewTestClient(true) + writer := GraphQLWSMessageWriter{ + logger: abstractlogger.Noop{}, + Client: testClient, + mu: &sync.Mutex{}, + } + err := writer.WriteData("1", nil) + assert.Error(t, err) + }) + t.Run("should successfully write message data to client", func(t *testing.T) { + testClient := NewTestClient(false) + writer := GraphQLWSMessageWriter{ + logger: abstractlogger.Noop{}, + Client: testClient, + mu: &sync.Mutex{}, + } + expectedMessage := []byte(`{"id":"1","type":"data","payload":{"data":{"hello":"world"}}}`) + err := writer.WriteData("1", []byte(`{"data":{"hello":"world"}}`)) + assert.NoError(t, err) + assert.Equal(t, expectedMessage, testClient.readMessageToClient()) + }) +} + +func TestGraphQLWSMessageWriter_WriteComplete(t *testing.T) { + t.Run("should return error when error occurs on underlying call", func(t *testing.T) { + testClient := NewTestClient(true) + writer := GraphQLWSMessageWriter{ + logger: abstractlogger.Noop{}, + Client: testClient, + mu: &sync.Mutex{}, + } + err := writer.WriteComplete("1") + assert.Error(t, err) + }) + t.Run("should successfully write complete message to client", func(t *testing.T) { + testClient := NewTestClient(false) + writer := GraphQLWSMessageWriter{ + logger: abstractlogger.Noop{}, + Client: testClient, + mu: &sync.Mutex{}, + } + expectedMessage := []byte(`{"id":"1","type":"complete"}`) + err := writer.WriteComplete("1") + assert.NoError(t, err) + assert.Equal(t, expectedMessage, testClient.readMessageToClient()) + }) +} + +func TestGraphQLWSMessageWriter_WriteKeepAlive(t *testing.T) { + t.Run("should return error when error occurs on underlying call", func(t *testing.T) { + testClient := NewTestClient(true) + writer := GraphQLWSMessageWriter{ + logger: abstractlogger.Noop{}, + Client: testClient, + mu: &sync.Mutex{}, + } + err := writer.WriteKeepAlive() + assert.Error(t, err) + }) + t.Run("should successfully write keep-alive (ka) message to client", func(t *testing.T) { + testClient := NewTestClient(false) + writer := GraphQLWSMessageWriter{ + logger: abstractlogger.Noop{}, + Client: testClient, + mu: &sync.Mutex{}, + } + expectedMessage := []byte(`{"type":"ka"}`) + err := writer.WriteKeepAlive() + assert.NoError(t, err) + assert.Equal(t, expectedMessage, testClient.readMessageToClient()) + }) +} + +func TestGraphQLWSMessageWriter_WriteTerminate(t *testing.T) { + t.Run("should return error when error occurs on underlying call", func(t *testing.T) { + testClient := NewTestClient(true) + writer := GraphQLWSMessageWriter{ + logger: abstractlogger.Noop{}, + Client: testClient, + mu: &sync.Mutex{}, + } + err := writer.WriteTerminate(`failed to accept the websocket connection`) + assert.Error(t, err) + }) + t.Run("should successfully write terminate message to client", func(t *testing.T) { + testClient := NewTestClient(false) + writer := GraphQLWSMessageWriter{ + logger: abstractlogger.Noop{}, + Client: testClient, + mu: &sync.Mutex{}, + } + expectedMessage := []byte(`{"type":"connection_terminate","payload":"failed to accept the websocket connection"}`) + err := writer.WriteTerminate(`failed to accept the websocket connection`) + assert.NoError(t, err) + assert.Equal(t, expectedMessage, testClient.readMessageToClient()) + }) +} + +func TestGraphQLWSMessageWriter_WriteConnectionError(t *testing.T) { + t.Run("should return error when error occurs on underlying call", func(t *testing.T) { + testClient := NewTestClient(true) + writer := GraphQLWSMessageWriter{ + logger: abstractlogger.Noop{}, + Client: testClient, + mu: &sync.Mutex{}, + } + err := writer.WriteConnectionError(`could not read message from client`) + assert.Error(t, err) + }) + t.Run("should successfully write connection error message to client", func(t *testing.T) { + testClient := NewTestClient(false) + writer := GraphQLWSMessageWriter{ + logger: abstractlogger.Noop{}, + Client: testClient, + mu: &sync.Mutex{}, + } + expectedMessage := []byte(`{"type":"connection_error","payload":"could not read message from client"}`) + err := writer.WriteConnectionError(`could not read message from client`) + assert.NoError(t, err) + assert.Equal(t, expectedMessage, testClient.readMessageToClient()) + }) +} + +func TestGraphQLWSMessageWriter_WriteError(t *testing.T) { + t.Run("should return error when error occurs on underlying call", func(t *testing.T) { + testClient := NewTestClient(true) + writer := GraphQLWSMessageWriter{ + logger: abstractlogger.Noop{}, + Client: testClient, + mu: &sync.Mutex{}, + } + requestErrors := graphql.RequestErrorsFromError(errors.New("request error")) + err := writer.WriteError("1", requestErrors) + assert.Error(t, err) + }) + t.Run("should successfully write error message to client", func(t *testing.T) { + testClient := NewTestClient(false) + writer := GraphQLWSMessageWriter{ + logger: abstractlogger.Noop{}, + Client: testClient, + mu: &sync.Mutex{}, + } + expectedMessage := []byte(`{"id":"1","type":"error","payload":[{"message":"request error"}]}`) + requestErrors := graphql.RequestErrorsFromError(errors.New("request error")) + err := writer.WriteError("1", requestErrors) + assert.NoError(t, err) + assert.Equal(t, expectedMessage, testClient.readMessageToClient()) + }) +} + +func TestGraphQLWSMessageWriter_WriteAck(t *testing.T) { + t.Run("should return error when error occurs on underlying call", func(t *testing.T) { + testClient := NewTestClient(true) + writer := GraphQLWSMessageWriter{ + logger: abstractlogger.Noop{}, + Client: testClient, + mu: &sync.Mutex{}, + } + err := writer.WriteAck() + assert.Error(t, err) + }) + t.Run("should successfully write ack message to client", func(t *testing.T) { + testClient := NewTestClient(false) + writer := GraphQLWSMessageWriter{ + logger: abstractlogger.Noop{}, + Client: testClient, + mu: &sync.Mutex{}, + } + expectedMessage := []byte(`{"type":"connection_ack"}`) + err := writer.WriteAck() + assert.NoError(t, err) + assert.Equal(t, expectedMessage, testClient.readMessageToClient()) + }) +} + +func TestGraphQLWSWriteEventHandler_Emit(t *testing.T) { + t.Run("should write on completed", func(t *testing.T) { + testClient := NewTestClient(false) + writeEventHandler := NewTestGraphQLWSWriteEventHandler(testClient) + writeEventHandler.Emit(subscription.EventTypeOnSubscriptionCompleted, "1", nil, nil) + expectedMessage := []byte(`{"id":"1","type":"complete"}`) + assert.Equal(t, expectedMessage, testClient.readMessageToClient()) + }) + t.Run("should write on data", func(t *testing.T) { + testClient := NewTestClient(false) + writeEventHandler := NewTestGraphQLWSWriteEventHandler(testClient) + writeEventHandler.Emit(subscription.EventTypeOnSubscriptionData, "1", []byte(`{ "data": { "hello": "world" } }`), nil) + expectedMessage := []byte(`{"id":"1","type":"data","payload":{"data":{"hello":"world"}}}`) + assert.Equal(t, expectedMessage, testClient.readMessageToClient()) + }) + t.Run("should write on error", func(t *testing.T) { + testClient := NewTestClient(false) + writeEventHandler := NewTestGraphQLWSWriteEventHandler(testClient) + writeEventHandler.Emit(subscription.EventTypeOnError, "1", nil, errors.New("error occurred")) + expectedMessage := []byte(`{"id":"1","type":"error","payload":[{"message":"error occurred"}]}`) + assert.Equal(t, expectedMessage, testClient.readMessageToClient()) + }) + t.Run("should write on duplicated subscriber id", func(t *testing.T) { + testClient := NewTestClient(false) + writeEventHandler := NewTestGraphQLWSWriteEventHandler(testClient) + writeEventHandler.Emit(subscription.EventTypeOnDuplicatedSubscriberID, "1", nil, subscription.ErrSubscriberIDAlreadyExists) + expectedMessage := []byte(`{"id":"1","type":"error","payload":[{"message":"subscriber id already exists"}]}`) + assert.Equal(t, expectedMessage, testClient.readMessageToClient()) + }) + t.Run("should write on connection_error", func(t *testing.T) { + testClient := NewTestClient(false) + writeEventHandler := NewTestGraphQLWSWriteEventHandler(testClient) + writeEventHandler.Emit(subscription.EventTypeOnConnectionError, "", nil, errors.New("connection error occurred")) + expectedMessage := []byte(`{"type":"connection_error","payload":"connection error occurred"}`) + assert.Equal(t, expectedMessage, testClient.readMessageToClient()) + }) + t.Run("should write on non-subscription execution result", func(t *testing.T) { + testClient := NewTestClient(false) + writeEventHandler := NewTestGraphQLWSWriteEventHandler(testClient) + go func() { + writeEventHandler.Emit(subscription.EventTypeOnNonSubscriptionExecutionResult, "1", []byte(`{ "data": { "hello": "world" } }`), nil) + }() + + assert.Eventually(t, func() bool { + expectedDataMessage := []byte(`{"id":"1","type":"data","payload":{"data":{"hello":"world"}}}`) + actualDataMessage := testClient.readMessageToClient() + assert.Equal(t, expectedDataMessage, actualDataMessage) + expectedCompleteMessage := []byte(`{"id":"1","type":"complete"}`) + actualCompleteMessage := testClient.readMessageToClient() + assert.Equal(t, expectedCompleteMessage, actualCompleteMessage) + return true + }, 1*time.Second, 2*time.Millisecond) + }) +} + +func TestGraphQLWSWriteEventHandler_HandleWriteEvent(t *testing.T) { + t.Run("should write keep_alive", func(t *testing.T) { + testClient := NewTestClient(false) + writeEventHandler := NewTestGraphQLWSWriteEventHandler(testClient) + writeEventHandler.HandleWriteEvent(GraphQLWSMessageTypeConnectionKeepAlive, "", nil, nil) + expectedMessage := []byte(`{"type":"ka"}`) + assert.Equal(t, expectedMessage, testClient.readMessageToClient()) + }) + t.Run("should write ack", func(t *testing.T) { + testClient := NewTestClient(false) + writeEventHandler := NewTestGraphQLWSWriteEventHandler(testClient) + writeEventHandler.HandleWriteEvent(GraphQLWSMessageTypeConnectionAck, "", nil, nil) + expectedMessage := []byte(`{"type":"connection_ack"}`) + assert.Equal(t, expectedMessage, testClient.readMessageToClient()) + }) +} + +func TestProtocolGraphQLWSHandler_Handle(t *testing.T) { + t.Run("should return connection_error when an unexpected message type is used", func(t *testing.T) { + testClient := NewTestClient(false) + protocol := NewTestProtocolGraphQLWSHandler(testClient) + + ctrl := gomock.NewController(t) + mockEngine := NewMockEngine(ctrl) + + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + expectedMessage := []byte(`{"type":"connection_error","payload":"unexpected message type: something"}`) + err := protocol.Handle(ctx, mockEngine, []byte(`{"type":"something"}`)) + assert.NoError(t, err) + assert.Equal(t, testClient.readMessageToClient(), expectedMessage) + }) + + t.Run("should terminate connections on connection_terminate from client", func(t *testing.T) { + testClient := NewTestClient(false) + protocol := NewTestProtocolGraphQLWSHandler(testClient) + + ctrl := gomock.NewController(t) + mockEngine := NewMockEngine(ctrl) + mockEngine.EXPECT().TerminateAllSubscriptions(gomock.Eq(protocol.EventHandler())) + + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + err := protocol.Handle(ctx, mockEngine, []byte(`{"type":"connection_terminate"}`)) + assert.NoError(t, err) + }) + + t.Run("should init connection and respond with ack and ka", func(t *testing.T) { + testClient := NewTestClient(false) + protocol := NewTestProtocolGraphQLWSHandler(testClient) + protocol.keepAliveInterval = 5 * time.Millisecond + + ctrl := gomock.NewController(t) + mockEngine := NewMockEngine(ctrl) + + ctx, cancelFunc := context.WithCancel(context.Background()) + + assert.Eventually(t, func() bool { + expectedMessageAck := []byte(`{"type":"connection_ack"}`) + expectedMessageKeepAlive := []byte(`{"type":"ka"}`) + err := protocol.Handle(ctx, mockEngine, []byte(`{"type":"connection_init"}`)) + assert.NoError(t, err) + assert.Equal(t, expectedMessageAck, testClient.readMessageToClient()) + + time.Sleep(8 * time.Millisecond) + assert.Equal(t, expectedMessageKeepAlive, testClient.readMessageToClient()) + cancelFunc() + + return true + }, 1*time.Second, 5*time.Millisecond) + + }) + + t.Run("should start an operation on start from client", func(t *testing.T) { + testClient := NewTestClient(false) + protocol := NewTestProtocolGraphQLWSHandler(testClient) + + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + ctrl := gomock.NewController(t) + mockEngine := NewMockEngine(ctrl) + mockEngine.EXPECT().StartOperation(gomock.Eq(ctx), "1", []byte(`{"query":"{ hello }"}`), gomock.Eq(protocol.EventHandler())) + + err := protocol.Handle(ctx, mockEngine, []byte(`{"id":"1","type":"start","payload":{"query":"{ hello }"}}`)) + assert.NoError(t, err) + }) + + t.Run("should stop a subscription on stop from client", func(t *testing.T) { + testClient := NewTestClient(false) + protocol := NewTestProtocolGraphQLWSHandler(testClient) + + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + ctrl := gomock.NewController(t) + mockEngine := NewMockEngine(ctrl) + mockEngine.EXPECT().StopSubscription("1", gomock.Eq(protocol.EventHandler())) + + err := protocol.Handle(ctx, mockEngine, []byte(`{"id":"1","type":"stop"}`)) + assert.NoError(t, err) + }) + + t.Run("should not panic on broken input", func(t *testing.T) { + testClient := NewTestClient(false) + protocol := NewTestProtocolGraphQLWSHandler(testClient) + + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + ctrl := gomock.NewController(t) + mockEngine := NewMockEngine(ctrl) + + err := protocol.Handle(ctx, mockEngine, []byte(`{"type":"connection_init","payload":{something}}`)) + assert.NoError(t, err) + + expectedMessage := []byte(`{"type":"error","payload":[{"message":"json syntax error"}]}`) + actualMessage := testClient.readMessageToClient() + assert.Equal(t, expectedMessage, actualMessage) + }) +} + +func NewTestGraphQLWSWriteEventHandler(testClient subscription.TransportClient) GraphQLWSWriteEventHandler { + return GraphQLWSWriteEventHandler{ + logger: abstractlogger.Noop{}, + Writer: GraphQLWSMessageWriter{ + logger: abstractlogger.Noop{}, + mu: &sync.Mutex{}, + Client: testClient, + }, + } +} + +func NewTestProtocolGraphQLWSHandler(testClient subscription.TransportClient) *ProtocolGraphQLWSHandler { + return &ProtocolGraphQLWSHandler{ + logger: abstractlogger.Noop{}, + reader: GraphQLWSMessageReader{ + logger: abstractlogger.Noop{}, + }, + writeEventHandler: NewTestGraphQLWSWriteEventHandler(testClient), + keepAliveInterval: 30, + } +} From a913d72c97f327861d7c4a4ddb6dceb4962adf8c Mon Sep 17 00:00:00 2001 From: fiam Date: Wed, 4 Oct 2023 15:30:55 +0100 Subject: [PATCH 2/5] fix: remaining imports --- v2/pkg/subscription/engine.go | 8 ++++---- v2/pkg/subscription/engine_test.go | 4 ++-- v2/pkg/subscription/executor.go | 4 ++-- v2/pkg/subscription/executor_mock_test.go | 4 ++-- v2/pkg/subscription/websocket/client.go | 2 +- v2/pkg/subscription/websocket/client_test.go | 2 +- v2/pkg/subscription/websocket/engine_mock_test.go | 2 +- v2/pkg/subscription/websocket/handler.go | 2 +- v2/pkg/subscription/websocket/handler_test.go | 12 ++++++------ .../websocket/protocol_graphql_transport_ws.go | 4 ++-- .../websocket/protocol_graphql_transport_ws_test.go | 4 ++-- v2/pkg/subscription/websocket/protocol_graphql_ws.go | 4 ++-- .../websocket/protocol_graphql_ws_test.go | 4 ++-- 13 files changed, 28 insertions(+), 28 deletions(-) diff --git a/v2/pkg/subscription/engine.go b/v2/pkg/subscription/engine.go index 313bfef9f..a2992b05f 100644 --- a/v2/pkg/subscription/engine.go +++ b/v2/pkg/subscription/engine.go @@ -12,8 +12,8 @@ import ( "github.com/jensneuse/abstractlogger" - "github.com/wundergraph/graphql-go-tools/pkg/ast" - "github.com/wundergraph/graphql-go-tools/pkg/graphql" + "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" + "github.com/wundergraph/graphql-go-tools/v2/pkg/graphql" ) type errOnBeforeStartHookFailure struct { @@ -100,8 +100,8 @@ func (e *ExecutorEngine) handleOnBeforeStart(executor Executor) error { if hook := e.engine.GetWebsocketBeforeStartHook(); hook != nil { return hook.OnBeforeStart(e.reqCtx, e.operation) } - case *ExecutorV1: - // do nothing + default: + return fmt.Errorf("unsupported executor type: %T", executor) } return nil diff --git a/v2/pkg/subscription/engine_test.go b/v2/pkg/subscription/engine_test.go index 5ad01ae34..1d3c89a56 100644 --- a/v2/pkg/subscription/engine_test.go +++ b/v2/pkg/subscription/engine_test.go @@ -13,8 +13,8 @@ import ( "github.com/jensneuse/abstractlogger" "github.com/stretchr/testify/assert" - "github.com/wundergraph/graphql-go-tools/pkg/ast" - "github.com/wundergraph/graphql-go-tools/pkg/graphql" + "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" + "github.com/wundergraph/graphql-go-tools/v2/pkg/graphql" ) func TestExecutorEngine_StartOperation(t *testing.T) { diff --git a/v2/pkg/subscription/executor.go b/v2/pkg/subscription/executor.go index 81c2ae2dc..c77d50b89 100644 --- a/v2/pkg/subscription/executor.go +++ b/v2/pkg/subscription/executor.go @@ -5,8 +5,8 @@ package subscription import ( "context" - "github.com/wundergraph/graphql-go-tools/pkg/ast" - "github.com/wundergraph/graphql-go-tools/pkg/engine/resolve" + "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" ) // Executor is an abstraction for executing a GraphQL engine diff --git a/v2/pkg/subscription/executor_mock_test.go b/v2/pkg/subscription/executor_mock_test.go index 6d7716678..2b2681d86 100644 --- a/v2/pkg/subscription/executor_mock_test.go +++ b/v2/pkg/subscription/executor_mock_test.go @@ -9,8 +9,8 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" - ast "github.com/wundergraph/graphql-go-tools/pkg/ast" - resolve "github.com/wundergraph/graphql-go-tools/pkg/engine/resolve" + ast "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" + resolve "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" ) // MockExecutor is a mock of Executor interface. diff --git a/v2/pkg/subscription/websocket/client.go b/v2/pkg/subscription/websocket/client.go index 003eccdab..971686a1c 100644 --- a/v2/pkg/subscription/websocket/client.go +++ b/v2/pkg/subscription/websocket/client.go @@ -10,7 +10,7 @@ import ( "github.com/gobwas/ws/wsutil" "github.com/jensneuse/abstractlogger" - "github.com/wundergraph/graphql-go-tools/pkg/subscription" + "github.com/wundergraph/graphql-go-tools/v2/pkg/subscription" ) // CloseReason is type that is used to provide a close reason to Client.DisconnectWithReason. diff --git a/v2/pkg/subscription/websocket/client_test.go b/v2/pkg/subscription/websocket/client_test.go index 85be6940e..d39a99f81 100644 --- a/v2/pkg/subscription/websocket/client_test.go +++ b/v2/pkg/subscription/websocket/client_test.go @@ -15,7 +15,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/wundergraph/graphql-go-tools/pkg/subscription" + "github.com/wundergraph/graphql-go-tools/v2/pkg/subscription" ) type testServerWebsocketResponse struct { diff --git a/v2/pkg/subscription/websocket/engine_mock_test.go b/v2/pkg/subscription/websocket/engine_mock_test.go index da14ffc45..673d9db04 100644 --- a/v2/pkg/subscription/websocket/engine_mock_test.go +++ b/v2/pkg/subscription/websocket/engine_mock_test.go @@ -9,7 +9,7 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" - subscription "github.com/wundergraph/graphql-go-tools/pkg/subscription" + subscription "github.com/wundergraph/graphql-go-tools/v2/pkg/subscription" ) // MockEngine is a mock of Engine interface. diff --git a/v2/pkg/subscription/websocket/handler.go b/v2/pkg/subscription/websocket/handler.go index 226b8753e..f92e8c7a8 100644 --- a/v2/pkg/subscription/websocket/handler.go +++ b/v2/pkg/subscription/websocket/handler.go @@ -8,7 +8,7 @@ import ( "github.com/jensneuse/abstractlogger" - "github.com/wundergraph/graphql-go-tools/pkg/subscription" + "github.com/wundergraph/graphql-go-tools/v2/pkg/subscription" ) const ( diff --git a/v2/pkg/subscription/websocket/handler_test.go b/v2/pkg/subscription/websocket/handler_test.go index cc4fc8a84..0f068aebe 100644 --- a/v2/pkg/subscription/websocket/handler_test.go +++ b/v2/pkg/subscription/websocket/handler_test.go @@ -14,12 +14,12 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/wundergraph/graphql-go-tools/pkg/engine/datasource/graphql_datasource" - "github.com/wundergraph/graphql-go-tools/pkg/engine/datasource/httpclient" - "github.com/wundergraph/graphql-go-tools/pkg/engine/plan" - "github.com/wundergraph/graphql-go-tools/pkg/graphql" - "github.com/wundergraph/graphql-go-tools/pkg/subscription" - "github.com/wundergraph/graphql-go-tools/pkg/testing/subscriptiontesting" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/plan" + "github.com/wundergraph/graphql-go-tools/v2/pkg/graphql" + "github.com/wundergraph/graphql-go-tools/v2/pkg/subscription" + "github.com/wundergraph/graphql-go-tools/v2/pkg/testing/subscriptiontesting" ) func TestHandleWithOptions(t *testing.T) { diff --git a/v2/pkg/subscription/websocket/protocol_graphql_transport_ws.go b/v2/pkg/subscription/websocket/protocol_graphql_transport_ws.go index 7ecc17b56..e2d7c69a4 100644 --- a/v2/pkg/subscription/websocket/protocol_graphql_transport_ws.go +++ b/v2/pkg/subscription/websocket/protocol_graphql_transport_ws.go @@ -10,8 +10,8 @@ import ( "github.com/jensneuse/abstractlogger" - "github.com/wundergraph/graphql-go-tools/pkg/graphql" - "github.com/wundergraph/graphql-go-tools/pkg/subscription" + "github.com/wundergraph/graphql-go-tools/v2/pkg/graphql" + "github.com/wundergraph/graphql-go-tools/v2/pkg/subscription" ) // GraphQLTransportWSMessageType is a type that defines graphql-transport-ws message type names. diff --git a/v2/pkg/subscription/websocket/protocol_graphql_transport_ws_test.go b/v2/pkg/subscription/websocket/protocol_graphql_transport_ws_test.go index 56440880d..e01e1c513 100644 --- a/v2/pkg/subscription/websocket/protocol_graphql_transport_ws_test.go +++ b/v2/pkg/subscription/websocket/protocol_graphql_transport_ws_test.go @@ -12,8 +12,8 @@ import ( "github.com/jensneuse/abstractlogger" "github.com/stretchr/testify/assert" - "github.com/wundergraph/graphql-go-tools/pkg/graphql" - "github.com/wundergraph/graphql-go-tools/pkg/subscription" + "github.com/wundergraph/graphql-go-tools/v2/pkg/graphql" + "github.com/wundergraph/graphql-go-tools/v2/pkg/subscription" ) func TestGraphQLTransportWSMessageReader_Read(t *testing.T) { diff --git a/v2/pkg/subscription/websocket/protocol_graphql_ws.go b/v2/pkg/subscription/websocket/protocol_graphql_ws.go index 8d16d17f7..a1750d8a6 100644 --- a/v2/pkg/subscription/websocket/protocol_graphql_ws.go +++ b/v2/pkg/subscription/websocket/protocol_graphql_ws.go @@ -10,8 +10,8 @@ import ( "github.com/jensneuse/abstractlogger" - "github.com/wundergraph/graphql-go-tools/pkg/graphql" - "github.com/wundergraph/graphql-go-tools/pkg/subscription" + "github.com/wundergraph/graphql-go-tools/v2/pkg/graphql" + "github.com/wundergraph/graphql-go-tools/v2/pkg/subscription" ) // GraphQLWSMessageType is a type that defines graphql-ws message type names. diff --git a/v2/pkg/subscription/websocket/protocol_graphql_ws_test.go b/v2/pkg/subscription/websocket/protocol_graphql_ws_test.go index 4c1420a6a..dd5b664ad 100644 --- a/v2/pkg/subscription/websocket/protocol_graphql_ws_test.go +++ b/v2/pkg/subscription/websocket/protocol_graphql_ws_test.go @@ -12,8 +12,8 @@ import ( "github.com/jensneuse/abstractlogger" "github.com/stretchr/testify/assert" - "github.com/wundergraph/graphql-go-tools/pkg/graphql" - "github.com/wundergraph/graphql-go-tools/pkg/subscription" + "github.com/wundergraph/graphql-go-tools/v2/pkg/graphql" + "github.com/wundergraph/graphql-go-tools/v2/pkg/subscription" ) func TestGraphQLWSMessageReader_Read(t *testing.T) { From af1465be19e08378ad761030dc9cd5546aba2c66 Mon Sep 17 00:00:00 2001 From: fiam Date: Wed, 4 Oct 2023 20:33:27 +0100 Subject: [PATCH 3/5] chore: re-run code generation, fix errors --- v2/go.mod | 25 +-- v2/go.sum | 59 +++--- v2/pkg/ast/ast_string.go | 5 +- v2/pkg/ast/directive_location_string.go | 10 +- v2/pkg/introspection/introspection_enum.go | 75 +++++--- v2/pkg/subscription/engine.go | 3 +- v2/pkg/subscription/engine_mock_test.go | 2 +- v2/pkg/subscription/executor_mock_test.go | 2 +- v2/pkg/subscription/handler_mock_test.go | 2 +- .../transport_client_mock_test.go | 2 +- .../websocket/engine_mock_test.go | 2 +- .../accounts/graph/generated/generated.go | 44 ++--- .../products/graph/generated/generated.go | 182 +++++++++++++++++- .../products/graph/schema.resolvers.go | 9 + .../reviews/graph/generated/generated.go | 48 ++--- .../testing/subscriptiontesting/generated.go | 18 +- 16 files changed, 343 insertions(+), 145 deletions(-) diff --git a/v2/go.mod b/v2/go.mod index 22c5015b9..2e6e49f1f 100644 --- a/v2/go.mod +++ b/v2/go.mod @@ -9,8 +9,8 @@ require ( github.com/dave/jennifer v1.4.0 github.com/davecgh/go-spew v1.1.1 github.com/gobwas/ws v1.0.4 - github.com/golang/mock v1.4.1 - github.com/google/go-cmp v0.5.8 + github.com/golang/mock v1.6.0 + github.com/google/go-cmp v0.5.9 github.com/gorilla/websocket v1.5.0 github.com/hashicorp/golang-lru v0.5.4 github.com/iancoleman/strcase v0.0.0-20191112232945-16388991a334 @@ -24,19 +24,21 @@ require ( github.com/sebdah/goldie/v2 v2.5.3 github.com/spf13/cobra v0.0.5 github.com/spf13/viper v1.3.2 - github.com/stretchr/testify v1.7.1 + github.com/stretchr/testify v1.8.4 github.com/tidwall/gjson v1.11.0 github.com/tidwall/sjson v1.0.4 github.com/vektah/gqlparser/v2 v2.5.1 go.uber.org/atomic v1.9.0 + go.uber.org/multierr v1.6.0 go.uber.org/zap v1.18.1 golang.org/x/exp v0.0.0-20230203172020-98cc5a0785f9 - golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4 + golang.org/x/sync v0.3.0 gopkg.in/yaml.v2 v2.4.0 nhooyr.io/websocket v1.8.7 ) require ( + github.com/BurntSushi/toml v1.3.2 // indirect github.com/agnivade/levenshtein v1.1.1 // indirect github.com/fsnotify/fsnotify v1.4.9 // indirect github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee // indirect @@ -44,27 +46,26 @@ require ( github.com/hashicorp/hcl v1.0.0 // indirect github.com/inconshreveable/mousetrap v1.0.0 // indirect github.com/klauspost/compress v1.14.4 // indirect - github.com/kr/text v0.2.0 // indirect github.com/logrusorgru/aurora/v3 v3.0.0 // indirect github.com/magiconair/properties v1.8.0 // indirect github.com/mattn/go-colorable v0.1.13 // indirect - github.com/mattn/go-isatty v0.0.16 // indirect + github.com/mattn/go-isatty v0.0.19 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/pelletier/go-toml v1.6.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/sergi/go-diff v1.1.0 // indirect github.com/sirupsen/logrus v1.8.1 // indirect github.com/spf13/afero v1.6.0 // indirect - github.com/spf13/cast v1.3.0 // indirect + github.com/spf13/cast v1.5.1 // indirect github.com/spf13/jwalterweatherman v1.0.0 // indirect - github.com/spf13/pflag v1.0.3 // indirect + github.com/spf13/pflag v1.0.5 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect - go.uber.org/multierr v1.6.0 // indirect golang.org/x/lint v0.0.0-20191125180803-fdd1cda4f05f // indirect - golang.org/x/net v0.11.0 // indirect - golang.org/x/sys v0.9.0 // indirect - golang.org/x/text v0.10.0 // indirect + golang.org/x/net v0.15.0 // indirect + golang.org/x/sys v0.12.0 // indirect + golang.org/x/text v0.13.0 // indirect + golang.org/x/tools v0.13.0 // indirect gopkg.in/cenkalti/backoff.v1 v1.1.0 // indirect gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/v2/go.sum b/v2/go.sum index 79763ac71..d1a884902 100644 --- a/v2/go.sum +++ b/v2/go.sum @@ -1,8 +1,9 @@ github.com/99designs/gqlgen v0.17.22 h1:TOcrF8t0T3I0za9JD3CB6ehq7dDEMjR9Onikf8Lc/04= github.com/99designs/gqlgen v0.17.22/go.mod h1:BMhYIhe4bp7OlCo5I2PnowSK/Wimpv/YlxfNkqZGwLo= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= -github.com/BurntSushi/toml v1.1.0 h1:ksErzDEI1khOiGPgpwuI7x2ebx/uXQNw7xJpn9Eq1+I= github.com/BurntSushi/toml v1.1.0/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ= +github.com/BurntSushi/toml v1.3.2 h1:o7IhLm0Msx3BaB+n3Ag7L8EVlByGnpq14C4YWiu/gL8= +github.com/BurntSushi/toml v1.3.2/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ= github.com/agnivade/levenshtein v1.0.1/go.mod h1:CURSv5d9Uaml+FovSIICkLbAUZ9S4RqaHDIsdSBg7lM= github.com/agnivade/levenshtein v1.1.1 h1:QY8M92nrzkmr798gCo3kmMyqXFzdQVpxLlGPRBij0P8= github.com/agnivade/levenshtein v1.1.1/go.mod h1:veldBMzWxcCG2ZvUTKD2kJNRdCk5hVbJomOvKkmgYbo= @@ -22,7 +23,6 @@ github.com/coreos/go-etcd v2.0.0+incompatible/go.mod h1:Jez6KQU2B/sWsbdaef3ED8Nz github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/cpuguy83/go-md2man v1.0.10/go.mod h1:SmD6nW6nTyfqj6ABTjUi3V3JVMnlJmwcJI5acqYI6dE= github.com/cpuguy83/go-md2man/v2 v2.0.1/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= -github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/dave/jennifer v1.4.0 h1:tNJFJmLDVTLu+v05mVZ88RINa3vQqnyyWkTKWYz0CwE= github.com/dave/jennifer v1.4.0/go.mod h1:fIb+770HOpJ2fmN9EPPKOqm1vMGhB+TwXKMZhrIygKg= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -30,6 +30,7 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgryski/trifles v0.0.0-20200323201526-dd97f9abfb48 h1:fRzb/w+pyskVMQ+UbP35JkH8yB7MYb4q/qhBarqZE6g= github.com/dgryski/trifles v0.0.0-20200323201526-dd97f9abfb48/go.mod h1:if7Fbed8SFyPtHLHbg49SI7NAdJiC5WIA09pe59rfAA= +github.com/frankban/quicktest v1.14.4 h1:g2rn0vABPOOXmZUj+vbmUp0lPoXEMuhTpIluN0XL9UY= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= @@ -51,16 +52,16 @@ github.com/gobwas/pool v0.2.0/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6Wezm github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM= github.com/gobwas/ws v1.0.4 h1:5eXU1CZhpQdq5kXbKb+sECH5Ia5KiO6CYzIzdlVx6Bs= github.com/gobwas/ws v1.0.4/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM= -github.com/golang/mock v1.4.1 h1:ocYkMQY5RrXTYgXl7ICpV0IXwlEQGwKIsery4gyXa1U= -github.com/golang/mock v1.4.1/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= +github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= +github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= github.com/golang/protobuf v1.3.5/go.mod h1:6O5/vntMXwX2lRkT1hjjk0nAC1IDOTvTlVgjlRvqsdk= github.com/golang/protobuf v1.5.0 h1:LUVKkCeviFUMKqHa4tXIIij/lbhnMbP7Fn5wKdKkRh4= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= -github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= @@ -90,12 +91,11 @@ github.com/klauspost/compress v1.14.4/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47e github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= -github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= -github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/leodido/go-urn v1.2.0 h1:hpXL4XnriNwQ/ABnpepYM/1vCLWNDfUNts8dX3xTG6Y= github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII= github.com/logrusorgru/aurora/v3 v3.0.0 h1:R6zcoZZbvVcGMvDCKo45A9U/lzYyzl5NfYIvznmDfE4= @@ -106,8 +106,9 @@ github.com/matryer/moq v0.2.7/go.mod h1:kITsx543GOENm48TUAQyJ9+SAvFSr7iGQXPoth/V github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= -github.com/mattn/go-isatty v0.0.16 h1:bq3VjFmv/sOjHtdEhmkEV4x1AJtvUvOJ2PFAZ5+peKQ= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y= github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= @@ -129,6 +130,7 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/r3labs/sse/v2 v2.8.1 h1:lZH+W4XOLIq88U5MIHOsLec7+R62uhz3bIi2yn0Sg8o= github.com/r3labs/sse/v2 v2.8.1/go.mod h1:Igau6Whc+F17QUgML1fYe1VPZzTV6EMCnYktEmkNJ7I= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= +github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/santhosh-tekuri/jsonschema/v5 v5.3.0 h1:uIkTLo0AGRc8l7h5l9r+GcYi9qfVPt6lD4/bhmzfiKo= @@ -144,14 +146,16 @@ github.com/sirupsen/logrus v1.8.1/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic github.com/spf13/afero v1.1.2/go.mod h1:j4pytiNVoe2o6bmDsKpLACNPDBIoEAkihy7loJ1B0CQ= github.com/spf13/afero v1.6.0 h1:xoax2sJ2DT8S8xA2paPFjDCScCNeWsg75VG0DLRreiY= github.com/spf13/afero v1.6.0/go.mod h1:Ai8FlHk4v/PARR026UzYexafAt9roJ7LcLMAmO6Z93I= -github.com/spf13/cast v1.3.0 h1:oget//CVOEoFewqQxwr0Ej5yjygnqGkvggSE/gB35Q8= github.com/spf13/cast v1.3.0/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE= +github.com/spf13/cast v1.5.1 h1:R+kOtfhWQE6TVQzY+4D7wJLBgkdVasCEFxSUBYBYIlA= +github.com/spf13/cast v1.5.1/go.mod h1:b9PdjNptOpzXr7Rq1q9gJML/2cdGQAo69NKzQ10KN48= github.com/spf13/cobra v0.0.5 h1:f0B+LkLX6DtmRH1isoNA9VTtNUK9K8xYd28JNNfOv/s= github.com/spf13/cobra v0.0.5/go.mod h1:3K3wKZymM7VvHMDS9+Akkh4K60UwM26emMESw8tLCHU= github.com/spf13/jwalterweatherman v1.0.0 h1:XHEdyB+EcvlqZamSM4ZOMGlc93t6AcsBEu9Gc1vn7yk= github.com/spf13/jwalterweatherman v1.0.0/go.mod h1:cQK4TGJAtQXfYWX+Ddv3mKDzgVb68N+wFjFa4jdeBTo= -github.com/spf13/pflag v1.0.3 h1:zPAT6CGy6wXeQ7NtTnaTerfKOsV6V6F8agHXFiazDkg= github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= +github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= +github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/viper v1.3.2 h1:VUFqw5KcqRf7i70GOzW7N+Q7+gxVBkSSqiXB12+JQ4M= github.com/spf13/viper v1.3.2/go.mod h1:ZiWeW+zYFKm7srdB9IoDzzZXaJaI5eL9QjNiN/DMA2s= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -160,8 +164,9 @@ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXf github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/tidwall/gjson v1.11.0 h1:C16pk7tQNiH6VlCrtIXL1w8GaOsi1X3W8KDkE1BuYd4= github.com/tidwall/gjson v1.11.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= @@ -180,6 +185,7 @@ github.com/vektah/gqlparser/v2 v2.5.1 h1:ZGu+bquAY23jsxDRcYpWjttRZrUz07LbiY77gUO github.com/vektah/gqlparser/v2 v2.5.1/go.mod h1:mPgqFBu/woKTVYWyNk8cO3kh4S/f4aRFZrvOnp3hmCs= github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q= github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673/go.mod h1:N3UwUGtsrSj3ccvlPHLoLsHnpR27oXr4ZE984MbSER8= +github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/goldmark v1.4.1/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= @@ -199,6 +205,7 @@ golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnf golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/exp v0.0.0-20230203172020-98cc5a0785f9 h1:frX3nT9RkKybPnjyI+yvZh6ZucTZatCCEm9D47sZ2zo= golang.org/x/exp v0.0.0-20230203172020-98cc5a0785f9/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= @@ -206,6 +213,7 @@ golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHl golang.org/x/lint v0.0.0-20191125180803-fdd1cda4f05f h1:J5lckAjkw6qYlOZNj90mLYNTEKDvWeuc1yieZ8qUzUE= golang.org/x/lint v0.0.0-20191125180803-fdd1cda4f05f/go.mod h1:5qLYkcX4OjUUV8bRuDixDT3tpyyb+LUpUlRWLxfhWrs= golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= +golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3/go.mod h1:3p9vT2HGsQu2K1YbXdKPJLVgG5VJdoTa1poYQBtP1AY= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= @@ -213,14 +221,16 @@ golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20191116160921-f9c825593386/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= -golang.org/x/net v0.11.0 h1:Gi2tvZIJyBtO9SDr1q9h5hEQCp/4L2RQ+ar0qjx2oNU= -golang.org/x/net v0.11.0/go.mod h1:2L/ixqYpgIVXmeoSA/4Lu7BzTG4KIyPIryS4IsOd1oQ= +golang.org/x/net v0.15.0 h1:ugBLEUaxABaB5AJqW9enI0ACdci2RUd4eP51NTBvuJ8= +golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4 h1:uVc8UZUe6tr40fFVnUP5Oj+veunVezqYl9z7DYw9xzw= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E= +golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= golang.org/x/sys v0.0.0-20181205085412-a5c9d58dba9a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -229,38 +239,41 @@ golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211019181941-9d821ace8654/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.9.0 h1:KS/R3tvhPqvJvwcKfnBHJwwthS11LRhmM5D59eEXa0s= -golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= -golang.org/x/text v0.10.0 h1:UpjohKhiEgNc0CSauXmwYftY1+LlaC75SJwh0SgCX58= -golang.org/x/text v0.10.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= +golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= -golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191108193012-7d206e10da11/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191125144606-a911d9008d1f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.10/go.mod h1:Uh6Zz+xoGYZom868N8YTex3t7RhtHDBrE8Gzo9bV56E= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= -golang.org/x/tools v0.6.0 h1:BOw41kyTf3PuCW1pVQf8+Cyg8pMlkYB1oo9iJ6D/lKM= +golang.org/x/tools v0.13.0 h1:Iey4qkscZuv0VvIt8E0neZjtPVQFSc870HQ448QgEmQ= +golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -288,5 +301,3 @@ gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= nhooyr.io/websocket v1.8.7 h1:usjR2uOr/zjjkVMy0lW+PPohFok7PCow5sDjLgX4P4g= nhooyr.io/websocket v1.8.7/go.mod h1:B70DZP8IakI65RVQ51MsWP/8jndNma26DVA/nFSCgW0= -rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0= -rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA= diff --git a/v2/pkg/ast/ast_string.go b/v2/pkg/ast/ast_string.go index 74d7ec4d7..2109f9050 100644 --- a/v2/pkg/ast/ast_string.go +++ b/v2/pkg/ast/ast_string.go @@ -146,11 +146,12 @@ func _() { _ = x[UnknownPathKind-0] _ = x[ArrayIndex-1] _ = x[FieldName-2] + _ = x[InlineFragmentName-3] } -const _PathKind_name = "UnknownPathKindArrayIndexFieldName" +const _PathKind_name = "UnknownPathKindArrayIndexFieldNameInlineFragmentName" -var _PathKind_index = [...]uint8{0, 15, 25, 34} +var _PathKind_index = [...]uint8{0, 15, 25, 34, 52} func (i PathKind) String() string { if i < 0 || i >= PathKind(len(_PathKind_index)-1) { diff --git a/v2/pkg/ast/directive_location_string.go b/v2/pkg/ast/directive_location_string.go index 505cb729d..b1db189dc 100644 --- a/v2/pkg/ast/directive_location_string.go +++ b/v2/pkg/ast/directive_location_string.go @@ -1,4 +1,4 @@ -// Code generated by "stringer -type=DirectiveLocation"; DO NOT EDIT. +// Code generated by "stringer -type=DirectiveLocation -output directive_location_string.go"; DO NOT EDIT. package ast @@ -34,9 +34,9 @@ const _DirectiveLocation_name = "DirectiveLocationUnknownExecutableDirectiveLoca var _DirectiveLocation_index = [...]uint16{0, 24, 56, 91, 130, 162, 207, 248, 289, 334, 367, 400, 433, 475, 520, 556, 588, 619, 655, 693, 740} -func (d DirectiveLocation) String() string { - if d < 0 || d >= DirectiveLocation(len(_DirectiveLocation_index)-1) { - return "DirectiveLocation(" + strconv.FormatInt(int64(d), 10) + ")" +func (i DirectiveLocation) String() string { + if i < 0 || i >= DirectiveLocation(len(_DirectiveLocation_index)-1) { + return "DirectiveLocation(" + strconv.FormatInt(int64(i), 10) + ")" } - return _DirectiveLocation_name[_DirectiveLocation_index[d]:_DirectiveLocation_index[d+1]] + return _DirectiveLocation_name[_DirectiveLocation_index[i]:_DirectiveLocation_index[i+1]] } diff --git a/v2/pkg/introspection/introspection_enum.go b/v2/pkg/introspection/introspection_enum.go index f66b96bcc..4c2f5262a 100644 --- a/v2/pkg/introspection/introspection_enum.go +++ b/v2/pkg/introspection/introspection_enum.go @@ -1,42 +1,48 @@ -// Code generated by go-enum -// DO NOT EDIT! +// Code generated by go-enum DO NOT EDIT. +// Version: +// Revision: +// Build Date: +// Built By: package introspection import ( + "errors" "fmt" ) const ( - // SCALAR is a __TypeKind of type SCALAR + // SCALAR is a __TypeKind of type SCALAR. SCALAR __TypeKind = iota - // LIST is a __TypeKind of type LIST + // LIST is a __TypeKind of type LIST. LIST - // NONNULL is a __TypeKind of type NON_NULL + // NONNULL is a __TypeKind of type NON_NULL. NONNULL - // OBJECT is a __TypeKind of type OBJECT + // OBJECT is a __TypeKind of type OBJECT. OBJECT - // ENUM is a __TypeKind of type ENUM + // ENUM is a __TypeKind of type ENUM. ENUM - // INTERFACE is a __TypeKind of type INTERFACE + // INTERFACE is a __TypeKind of type INTERFACE. INTERFACE - // UNION is a __TypeKind of type UNION + // UNION is a __TypeKind of type UNION. UNION - // INPUTOBJECT is a __TypeKind of type INPUT_OBJECT + // INPUTOBJECT is a __TypeKind of type INPUT_OBJECT. INPUTOBJECT ) +var ErrInvalid__TypeKind = errors.New("not a valid __TypeKind") + const ___TypeKindName = "SCALARLISTNON_NULLOBJECTENUMINTERFACEUNIONINPUT_OBJECT" var ___TypeKindMap = map[__TypeKind]string{ - 0: ___TypeKindName[0:6], - 1: ___TypeKindName[6:10], - 2: ___TypeKindName[10:18], - 3: ___TypeKindName[18:24], - 4: ___TypeKindName[24:28], - 5: ___TypeKindName[28:37], - 6: ___TypeKindName[37:42], - 7: ___TypeKindName[42:54], + SCALAR: ___TypeKindName[0:6], + LIST: ___TypeKindName[6:10], + NONNULL: ___TypeKindName[10:18], + OBJECT: ___TypeKindName[18:24], + ENUM: ___TypeKindName[24:28], + INTERFACE: ___TypeKindName[28:37], + UNION: ___TypeKindName[37:42], + INPUTOBJECT: ___TypeKindName[42:54], } // String implements the Stringer interface. @@ -47,31 +53,38 @@ func (x __TypeKind) String() string { return fmt.Sprintf("__TypeKind(%d)", x) } +// IsValid provides a quick way to determine if the typed value is +// part of the allowed enumerated values +func (x __TypeKind) IsValid() bool { + _, ok := ___TypeKindMap[x] + return ok +} + var ___TypeKindValue = map[string]__TypeKind{ - ___TypeKindName[0:6]: 0, - ___TypeKindName[6:10]: 1, - ___TypeKindName[10:18]: 2, - ___TypeKindName[18:24]: 3, - ___TypeKindName[24:28]: 4, - ___TypeKindName[28:37]: 5, - ___TypeKindName[37:42]: 6, - ___TypeKindName[42:54]: 7, + ___TypeKindName[0:6]: SCALAR, + ___TypeKindName[6:10]: LIST, + ___TypeKindName[10:18]: NONNULL, + ___TypeKindName[18:24]: OBJECT, + ___TypeKindName[24:28]: ENUM, + ___TypeKindName[28:37]: INTERFACE, + ___TypeKindName[37:42]: UNION, + ___TypeKindName[42:54]: INPUTOBJECT, } -// Parse__TypeKind attempts to convert a string to a __TypeKind +// Parse__TypeKind attempts to convert a string to a __TypeKind. func Parse__TypeKind(name string) (__TypeKind, error) { if x, ok := ___TypeKindValue[name]; ok { return x, nil } - return __TypeKind(0), fmt.Errorf("%s is not a valid __TypeKind", name) + return __TypeKind(0), fmt.Errorf("%s is %w", name, ErrInvalid__TypeKind) } -// MarshalText implements the text marshaller method -func (x *__TypeKind) MarshalText() ([]byte, error) { +// MarshalText implements the text marshaller method. +func (x __TypeKind) MarshalText() ([]byte, error) { return []byte(x.String()), nil } -// UnmarshalText implements the text unmarshaller method +// UnmarshalText implements the text unmarshaller method. func (x *__TypeKind) UnmarshalText(text []byte) error { name := string(text) tmp, err := Parse__TypeKind(name) diff --git a/v2/pkg/subscription/engine.go b/v2/pkg/subscription/engine.go index a2992b05f..6fb910a1b 100644 --- a/v2/pkg/subscription/engine.go +++ b/v2/pkg/subscription/engine.go @@ -101,7 +101,8 @@ func (e *ExecutorEngine) handleOnBeforeStart(executor Executor) error { return hook.OnBeforeStart(e.reqCtx, e.operation) } default: - return fmt.Errorf("unsupported executor type: %T", executor) + // Do nothing + break } return nil diff --git a/v2/pkg/subscription/engine_mock_test.go b/v2/pkg/subscription/engine_mock_test.go index fb5cd659b..8a630bd1f 100644 --- a/v2/pkg/subscription/engine_mock_test.go +++ b/v2/pkg/subscription/engine_mock_test.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/wundergraph/graphql-go-tools/pkg/subscription (interfaces: Engine) +// Source: github.com/wundergraph/graphql-go-tools/v2/pkg/subscription (interfaces: Engine) // Package subscription is a generated GoMock package. package subscription diff --git a/v2/pkg/subscription/executor_mock_test.go b/v2/pkg/subscription/executor_mock_test.go index 2b2681d86..b75104a05 100644 --- a/v2/pkg/subscription/executor_mock_test.go +++ b/v2/pkg/subscription/executor_mock_test.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/wundergraph/graphql-go-tools/pkg/subscription (interfaces: Executor,ExecutorPool) +// Source: github.com/wundergraph/graphql-go-tools/v2/pkg/subscription (interfaces: Executor,ExecutorPool) // Package subscription is a generated GoMock package. package subscription diff --git a/v2/pkg/subscription/handler_mock_test.go b/v2/pkg/subscription/handler_mock_test.go index f9124b83c..c08c660b9 100644 --- a/v2/pkg/subscription/handler_mock_test.go +++ b/v2/pkg/subscription/handler_mock_test.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/wundergraph/graphql-go-tools/pkg/subscription (interfaces: Protocol,EventHandler) +// Source: github.com/wundergraph/graphql-go-tools/v2/pkg/subscription (interfaces: Protocol,EventHandler) // Package subscription is a generated GoMock package. package subscription diff --git a/v2/pkg/subscription/transport_client_mock_test.go b/v2/pkg/subscription/transport_client_mock_test.go index cf1b3850f..fc5c24cbd 100644 --- a/v2/pkg/subscription/transport_client_mock_test.go +++ b/v2/pkg/subscription/transport_client_mock_test.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/wundergraph/graphql-go-tools/pkg/subscription (interfaces: TransportClient) +// Source: github.com/wundergraph/graphql-go-tools/v2/pkg/subscription (interfaces: TransportClient) // Package subscription is a generated GoMock package. package subscription diff --git a/v2/pkg/subscription/websocket/engine_mock_test.go b/v2/pkg/subscription/websocket/engine_mock_test.go index 673d9db04..d92185d06 100644 --- a/v2/pkg/subscription/websocket/engine_mock_test.go +++ b/v2/pkg/subscription/websocket/engine_mock_test.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/wundergraph/graphql-go-tools/pkg/subscription (interfaces: Engine) +// Source: github.com/wundergraph/graphql-go-tools/v2/pkg/subscription (interfaces: Engine) // Package websocket is a generated GoMock package. package websocket diff --git a/v2/pkg/testing/federationtesting/accounts/graph/generated/generated.go b/v2/pkg/testing/federationtesting/accounts/graph/generated/generated.go index e6fc31304..f08184a3b 100644 --- a/v2/pkg/testing/federationtesting/accounts/graph/generated/generated.go +++ b/v2/pkg/testing/federationtesting/accounts/graph/generated/generated.go @@ -627,7 +627,7 @@ func (ec *executionContext) _Entity_findUserByID(ctx context.Context, field grap } res := resTmp.(*model.User) fc.Result = res - return ec.marshalNUser2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋpkgᚋtestingᚋfederationtestingᚋaccountsᚋgraphᚋmodelᚐUser(ctx, field.Selections, res) + return ec.marshalNUser2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋv2ᚋpkgᚋtestingᚋfederationtestingᚋaccountsᚋgraphᚋmodelᚐUser(ctx, field.Selections, res) } func (ec *executionContext) fieldContext_Entity_findUserByID(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { @@ -736,7 +736,7 @@ func (ec *executionContext) _Purchase_product(ctx context.Context, field graphql } res := resTmp.(*model.Product) fc.Result = res - return ec.marshalNProduct2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋpkgᚋtestingᚋfederationtestingᚋaccountsᚋgraphᚋmodelᚐProduct(ctx, field.Selections, res) + return ec.marshalNProduct2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋv2ᚋpkgᚋtestingᚋfederationtestingᚋaccountsᚋgraphᚋmodelᚐProduct(ctx, field.Selections, res) } func (ec *executionContext) fieldContext_Purchase_product(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { @@ -781,7 +781,7 @@ func (ec *executionContext) _Purchase_wallet(ctx context.Context, field graphql. } res := resTmp.(model.Wallet) fc.Result = res - return ec.marshalOWallet2githubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋpkgᚋtestingᚋfederationtestingᚋaccountsᚋgraphᚋmodelᚐWallet(ctx, field.Selections, res) + return ec.marshalOWallet2githubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋv2ᚋpkgᚋtestingᚋfederationtestingᚋaccountsᚋgraphᚋmodelᚐWallet(ctx, field.Selections, res) } func (ec *executionContext) fieldContext_Purchase_wallet(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { @@ -866,7 +866,7 @@ func (ec *executionContext) _Query_me(ctx context.Context, field graphql.Collect } res := resTmp.(*model.User) fc.Result = res - return ec.marshalOUser2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋpkgᚋtestingᚋfederationtestingᚋaccountsᚋgraphᚋmodelᚐUser(ctx, field.Selections, res) + return ec.marshalOUser2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋv2ᚋpkgᚋtestingᚋfederationtestingᚋaccountsᚋgraphᚋmodelᚐUser(ctx, field.Selections, res) } func (ec *executionContext) fieldContext_Query_me(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { @@ -917,7 +917,7 @@ func (ec *executionContext) _Query_identifiable(ctx context.Context, field graph } res := resTmp.(model.Identifiable) fc.Result = res - return ec.marshalOIdentifiable2githubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋpkgᚋtestingᚋfederationtestingᚋaccountsᚋgraphᚋmodelᚐIdentifiable(ctx, field.Selections, res) + return ec.marshalOIdentifiable2githubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋv2ᚋpkgᚋtestingᚋfederationtestingᚋaccountsᚋgraphᚋmodelᚐIdentifiable(ctx, field.Selections, res) } func (ec *executionContext) fieldContext_Query_identifiable(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { @@ -958,7 +958,7 @@ func (ec *executionContext) _Query_histories(ctx context.Context, field graphql. } res := resTmp.([]model.History) fc.Result = res - return ec.marshalOHistory2ᚕgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋpkgᚋtestingᚋfederationtestingᚋaccountsᚋgraphᚋmodelᚐHistory(ctx, field.Selections, res) + return ec.marshalOHistory2ᚕgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋv2ᚋpkgᚋtestingᚋfederationtestingᚋaccountsᚋgraphᚋmodelᚐHistory(ctx, field.Selections, res) } func (ec *executionContext) fieldContext_Query_histories(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { @@ -999,7 +999,7 @@ func (ec *executionContext) _Query_cat(ctx context.Context, field graphql.Collec } res := resTmp.(*model.Cat) fc.Result = res - return ec.marshalOCat2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋpkgᚋtestingᚋfederationtestingᚋaccountsᚋgraphᚋmodelᚐCat(ctx, field.Selections, res) + return ec.marshalOCat2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋv2ᚋpkgᚋtestingᚋfederationtestingᚋaccountsᚋgraphᚋmodelᚐCat(ctx, field.Selections, res) } func (ec *executionContext) fieldContext_Query_cat(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { @@ -1279,7 +1279,7 @@ func (ec *executionContext) _Sale_product(ctx context.Context, field graphql.Col } res := resTmp.(*model.Product) fc.Result = res - return ec.marshalNProduct2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋpkgᚋtestingᚋfederationtestingᚋaccountsᚋgraphᚋmodelᚐProduct(ctx, field.Selections, res) + return ec.marshalNProduct2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋv2ᚋpkgᚋtestingᚋfederationtestingᚋaccountsᚋgraphᚋmodelᚐProduct(ctx, field.Selections, res) } func (ec *executionContext) fieldContext_Sale_product(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { @@ -1503,7 +1503,7 @@ func (ec *executionContext) _User_history(ctx context.Context, field graphql.Col } res := resTmp.([]model.History) fc.Result = res - return ec.marshalNHistory2ᚕgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋpkgᚋtestingᚋfederationtestingᚋaccountsᚋgraphᚋmodelᚐHistoryᚄ(ctx, field.Selections, res) + return ec.marshalNHistory2ᚕgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋv2ᚋpkgᚋtestingᚋfederationtestingᚋaccountsᚋgraphᚋmodelᚐHistoryᚄ(ctx, field.Selections, res) } func (ec *executionContext) fieldContext_User_history(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { @@ -4630,7 +4630,7 @@ func (ec *executionContext) marshalNFloat2float64(ctx context.Context, sel ast.S return graphql.WrapContextMarshaler(ctx, res) } -func (ec *executionContext) marshalNHistory2githubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋpkgᚋtestingᚋfederationtestingᚋaccountsᚋgraphᚋmodelᚐHistory(ctx context.Context, sel ast.SelectionSet, v model.History) graphql.Marshaler { +func (ec *executionContext) marshalNHistory2githubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋv2ᚋpkgᚋtestingᚋfederationtestingᚋaccountsᚋgraphᚋmodelᚐHistory(ctx context.Context, sel ast.SelectionSet, v model.History) graphql.Marshaler { if v == nil { if !graphql.HasFieldError(ctx, graphql.GetFieldContext(ctx)) { ec.Errorf(ctx, "the requested element is null which the schema does not allow") @@ -4640,7 +4640,7 @@ func (ec *executionContext) marshalNHistory2githubᚗcomᚋwundergraphᚋgraphql return ec._History(ctx, sel, v) } -func (ec *executionContext) marshalNHistory2ᚕgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋpkgᚋtestingᚋfederationtestingᚋaccountsᚋgraphᚋmodelᚐHistoryᚄ(ctx context.Context, sel ast.SelectionSet, v []model.History) graphql.Marshaler { +func (ec *executionContext) marshalNHistory2ᚕgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋv2ᚋpkgᚋtestingᚋfederationtestingᚋaccountsᚋgraphᚋmodelᚐHistoryᚄ(ctx context.Context, sel ast.SelectionSet, v []model.History) graphql.Marshaler { ret := make(graphql.Array, len(v)) var wg sync.WaitGroup isLen1 := len(v) == 1 @@ -4664,7 +4664,7 @@ func (ec *executionContext) marshalNHistory2ᚕgithubᚗcomᚋwundergraphᚋgrap if !isLen1 { defer wg.Done() } - ret[i] = ec.marshalNHistory2githubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋpkgᚋtestingᚋfederationtestingᚋaccountsᚋgraphᚋmodelᚐHistory(ctx, sel, v[i]) + ret[i] = ec.marshalNHistory2githubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋv2ᚋpkgᚋtestingᚋfederationtestingᚋaccountsᚋgraphᚋmodelᚐHistory(ctx, sel, v[i]) } if isLen1 { f(i) @@ -4714,7 +4714,7 @@ func (ec *executionContext) marshalNInt2int(ctx context.Context, sel ast.Selecti return res } -func (ec *executionContext) marshalNProduct2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋpkgᚋtestingᚋfederationtestingᚋaccountsᚋgraphᚋmodelᚐProduct(ctx context.Context, sel ast.SelectionSet, v *model.Product) graphql.Marshaler { +func (ec *executionContext) marshalNProduct2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋv2ᚋpkgᚋtestingᚋfederationtestingᚋaccountsᚋgraphᚋmodelᚐProduct(ctx context.Context, sel ast.SelectionSet, v *model.Product) graphql.Marshaler { if v == nil { if !graphql.HasFieldError(ctx, graphql.GetFieldContext(ctx)) { ec.Errorf(ctx, "the requested element is null which the schema does not allow") @@ -4739,11 +4739,11 @@ func (ec *executionContext) marshalNString2string(ctx context.Context, sel ast.S return res } -func (ec *executionContext) marshalNUser2githubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋpkgᚋtestingᚋfederationtestingᚋaccountsᚋgraphᚋmodelᚐUser(ctx context.Context, sel ast.SelectionSet, v model.User) graphql.Marshaler { +func (ec *executionContext) marshalNUser2githubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋv2ᚋpkgᚋtestingᚋfederationtestingᚋaccountsᚋgraphᚋmodelᚐUser(ctx context.Context, sel ast.SelectionSet, v model.User) graphql.Marshaler { return ec._User(ctx, sel, &v) } -func (ec *executionContext) marshalNUser2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋpkgᚋtestingᚋfederationtestingᚋaccountsᚋgraphᚋmodelᚐUser(ctx context.Context, sel ast.SelectionSet, v *model.User) graphql.Marshaler { +func (ec *executionContext) marshalNUser2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋv2ᚋpkgᚋtestingᚋfederationtestingᚋaccountsᚋgraphᚋmodelᚐUser(ctx context.Context, sel ast.SelectionSet, v *model.User) graphql.Marshaler { if v == nil { if !graphql.HasFieldError(ctx, graphql.GetFieldContext(ctx)) { ec.Errorf(ctx, "the requested element is null which the schema does not allow") @@ -5142,21 +5142,21 @@ func (ec *executionContext) marshalOBoolean2ᚖbool(ctx context.Context, sel ast return res } -func (ec *executionContext) marshalOCat2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋpkgᚋtestingᚋfederationtestingᚋaccountsᚋgraphᚋmodelᚐCat(ctx context.Context, sel ast.SelectionSet, v *model.Cat) graphql.Marshaler { +func (ec *executionContext) marshalOCat2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋv2ᚋpkgᚋtestingᚋfederationtestingᚋaccountsᚋgraphᚋmodelᚐCat(ctx context.Context, sel ast.SelectionSet, v *model.Cat) graphql.Marshaler { if v == nil { return graphql.Null } return ec._Cat(ctx, sel, v) } -func (ec *executionContext) marshalOHistory2githubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋpkgᚋtestingᚋfederationtestingᚋaccountsᚋgraphᚋmodelᚐHistory(ctx context.Context, sel ast.SelectionSet, v model.History) graphql.Marshaler { +func (ec *executionContext) marshalOHistory2githubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋv2ᚋpkgᚋtestingᚋfederationtestingᚋaccountsᚋgraphᚋmodelᚐHistory(ctx context.Context, sel ast.SelectionSet, v model.History) graphql.Marshaler { if v == nil { return graphql.Null } return ec._History(ctx, sel, v) } -func (ec *executionContext) marshalOHistory2ᚕgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋpkgᚋtestingᚋfederationtestingᚋaccountsᚋgraphᚋmodelᚐHistory(ctx context.Context, sel ast.SelectionSet, v []model.History) graphql.Marshaler { +func (ec *executionContext) marshalOHistory2ᚕgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋv2ᚋpkgᚋtestingᚋfederationtestingᚋaccountsᚋgraphᚋmodelᚐHistory(ctx context.Context, sel ast.SelectionSet, v []model.History) graphql.Marshaler { if v == nil { return graphql.Null } @@ -5183,7 +5183,7 @@ func (ec *executionContext) marshalOHistory2ᚕgithubᚗcomᚋwundergraphᚋgrap if !isLen1 { defer wg.Done() } - ret[i] = ec.marshalOHistory2githubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋpkgᚋtestingᚋfederationtestingᚋaccountsᚋgraphᚋmodelᚐHistory(ctx, sel, v[i]) + ret[i] = ec.marshalOHistory2githubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋv2ᚋpkgᚋtestingᚋfederationtestingᚋaccountsᚋgraphᚋmodelᚐHistory(ctx, sel, v[i]) } if isLen1 { f(i) @@ -5197,7 +5197,7 @@ func (ec *executionContext) marshalOHistory2ᚕgithubᚗcomᚋwundergraphᚋgrap return ret } -func (ec *executionContext) marshalOIdentifiable2githubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋpkgᚋtestingᚋfederationtestingᚋaccountsᚋgraphᚋmodelᚐIdentifiable(ctx context.Context, sel ast.SelectionSet, v model.Identifiable) graphql.Marshaler { +func (ec *executionContext) marshalOIdentifiable2githubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋv2ᚋpkgᚋtestingᚋfederationtestingᚋaccountsᚋgraphᚋmodelᚐIdentifiable(ctx context.Context, sel ast.SelectionSet, v model.Identifiable) graphql.Marshaler { if v == nil { return graphql.Null } @@ -5230,14 +5230,14 @@ func (ec *executionContext) marshalOString2ᚖstring(ctx context.Context, sel as return res } -func (ec *executionContext) marshalOUser2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋpkgᚋtestingᚋfederationtestingᚋaccountsᚋgraphᚋmodelᚐUser(ctx context.Context, sel ast.SelectionSet, v *model.User) graphql.Marshaler { +func (ec *executionContext) marshalOUser2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋv2ᚋpkgᚋtestingᚋfederationtestingᚋaccountsᚋgraphᚋmodelᚐUser(ctx context.Context, sel ast.SelectionSet, v *model.User) graphql.Marshaler { if v == nil { return graphql.Null } return ec._User(ctx, sel, v) } -func (ec *executionContext) marshalOWallet2githubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋpkgᚋtestingᚋfederationtestingᚋaccountsᚋgraphᚋmodelᚐWallet(ctx context.Context, sel ast.SelectionSet, v model.Wallet) graphql.Marshaler { +func (ec *executionContext) marshalOWallet2githubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋv2ᚋpkgᚋtestingᚋfederationtestingᚋaccountsᚋgraphᚋmodelᚐWallet(ctx context.Context, sel ast.SelectionSet, v model.Wallet) graphql.Marshaler { if v == nil { return graphql.Null } diff --git a/v2/pkg/testing/federationtesting/products/graph/generated/generated.go b/v2/pkg/testing/federationtesting/products/graph/generated/generated.go index 5d00b9025..755ab3bf4 100644 --- a/v2/pkg/testing/federationtesting/products/graph/generated/generated.go +++ b/v2/pkg/testing/federationtesting/products/graph/generated/generated.go @@ -39,6 +39,7 @@ type Config struct { type ResolverRoot interface { Entity() EntityResolver + Mutation() MutationResolver Query() QueryResolver Subscription() SubscriptionResolver } @@ -51,6 +52,10 @@ type ComplexityRoot struct { FindProductByUpc func(childComplexity int, upc string) int } + Mutation struct { + SetPrice func(childComplexity int, upc string, price int) int + } + Product struct { InStock func(childComplexity int) int Name func(childComplexity int) int @@ -77,6 +82,9 @@ type ComplexityRoot struct { type EntityResolver interface { FindProductByUpc(ctx context.Context, upc string) (*model.Product, error) } +type MutationResolver interface { + SetPrice(ctx context.Context, upc string, price int) (*model.Product, error) +} type QueryResolver interface { TopProducts(ctx context.Context, first *int) ([]*model.Product, error) } @@ -112,6 +120,18 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.Entity.FindProductByUpc(childComplexity, args["upc"].(string)), true + case "Mutation.setPrice": + if e.complexity.Mutation.SetPrice == nil { + break + } + + args, err := ec.field_Mutation_setPrice_args(context.TODO(), rawArgs) + if err != nil { + return 0, false + } + + return e.complexity.Mutation.SetPrice(childComplexity, args["upc"].(string), args["price"].(int)), true + case "Product.inStock": if e.complexity.Product.InStock == nil { break @@ -219,6 +239,21 @@ func (e *executableSchema) Exec(ctx context.Context) graphql.ResponseHandler { var buf bytes.Buffer data.MarshalGQL(&buf) + return &graphql.Response{ + Data: buf.Bytes(), + } + } + case ast.Mutation: + return func(ctx context.Context) *graphql.Response { + if !first { + return nil + } + first = false + ctx = graphql.WithUnmarshalerMap(ctx, inputUnmarshalMap) + data := ec._Mutation(ctx, rc.Operation.SelectionSet) + var buf bytes.Buffer + data.MarshalGQL(&buf) + return &graphql.Response{ Data: buf.Bytes(), } @@ -270,6 +305,10 @@ var sources = []*ast.Source{ topProducts(first: Int = 5): [Product] } +extend type Mutation { + setPrice(upc: String!, price: Int!): Product +} + extend type Subscription { updatedPrice: Product! updateProductPrice(upc: String!): Product! @@ -280,7 +319,8 @@ type Product @key(fields: "upc") { name: String! price: Int! inStock: Int! -}`, BuiltIn: false}, +} +`, BuiltIn: false}, {Name: "../../federation/directives.graphql", Input: ` scalar _Any scalar _FieldSet @@ -333,6 +373,30 @@ func (ec *executionContext) field_Entity_findProductByUpc_args(ctx context.Conte return args, nil } +func (ec *executionContext) field_Mutation_setPrice_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { + var err error + args := map[string]interface{}{} + var arg0 string + if tmp, ok := rawArgs["upc"]; ok { + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("upc")) + arg0, err = ec.unmarshalNString2string(ctx, tmp) + if err != nil { + return nil, err + } + } + args["upc"] = arg0 + var arg1 int + if tmp, ok := rawArgs["price"]; ok { + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("price")) + arg1, err = ec.unmarshalNInt2int(ctx, tmp) + if err != nil { + return nil, err + } + } + args["price"] = arg1 + return args, nil +} + func (ec *executionContext) field_Query___type_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { var err error args := map[string]interface{}{} @@ -459,7 +523,7 @@ func (ec *executionContext) _Entity_findProductByUpc(ctx context.Context, field } res := resTmp.(*model.Product) fc.Result = res - return ec.marshalNProduct2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋpkgᚋtestingᚋfederationtestingᚋproductsᚋgraphᚋmodelᚐProduct(ctx, field.Selections, res) + return ec.marshalNProduct2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋv2ᚋpkgᚋtestingᚋfederationtestingᚋproductsᚋgraphᚋmodelᚐProduct(ctx, field.Selections, res) } func (ec *executionContext) fieldContext_Entity_findProductByUpc(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { @@ -496,6 +560,68 @@ func (ec *executionContext) fieldContext_Entity_findProductByUpc(ctx context.Con return fc, nil } +func (ec *executionContext) _Mutation_setPrice(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_Mutation_setPrice(ctx, field) + if err != nil { + return graphql.Null + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return ec.resolvers.Mutation().SetPrice(rctx, fc.Args["upc"].(string), fc.Args["price"].(int)) + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + return graphql.Null + } + res := resTmp.(*model.Product) + fc.Result = res + return ec.marshalOProduct2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋv2ᚋpkgᚋtestingᚋfederationtestingᚋproductsᚋgraphᚋmodelᚐProduct(ctx, field.Selections, res) +} + +func (ec *executionContext) fieldContext_Mutation_setPrice(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "Mutation", + Field: field, + IsMethod: true, + IsResolver: true, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + switch field.Name { + case "upc": + return ec.fieldContext_Product_upc(ctx, field) + case "name": + return ec.fieldContext_Product_name(ctx, field) + case "price": + return ec.fieldContext_Product_price(ctx, field) + case "inStock": + return ec.fieldContext_Product_inStock(ctx, field) + } + return nil, fmt.Errorf("no field named %q was found under type Product", field.Name) + }, + } + defer func() { + if r := recover(); r != nil { + err = ec.Recover(ctx, r) + ec.Error(ctx, err) + } + }() + ctx = graphql.WithFieldContext(ctx, fc) + if fc.Args, err = ec.field_Mutation_setPrice_args(ctx, field.ArgumentMap(ec.Variables)); err != nil { + ec.Error(ctx, err) + return + } + return fc, nil +} + func (ec *executionContext) _Product_upc(ctx context.Context, field graphql.CollectedField, obj *model.Product) (ret graphql.Marshaler) { fc, err := ec.fieldContext_Product_upc(ctx, field) if err != nil { @@ -697,7 +823,7 @@ func (ec *executionContext) _Query_topProducts(ctx context.Context, field graphq } res := resTmp.([]*model.Product) fc.Result = res - return ec.marshalOProduct2ᚕᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋpkgᚋtestingᚋfederationtestingᚋproductsᚋgraphᚋmodelᚐProduct(ctx, field.Selections, res) + return ec.marshalOProduct2ᚕᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋv2ᚋpkgᚋtestingᚋfederationtestingᚋproductsᚋgraphᚋmodelᚐProduct(ctx, field.Selections, res) } func (ec *executionContext) fieldContext_Query_topProducts(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { @@ -1002,7 +1128,7 @@ func (ec *executionContext) _Subscription_updatedPrice(ctx context.Context, fiel w.Write([]byte{'{'}) graphql.MarshalString(field.Alias).MarshalGQL(w) w.Write([]byte{':'}) - ec.marshalNProduct2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋpkgᚋtestingᚋfederationtestingᚋproductsᚋgraphᚋmodelᚐProduct(ctx, field.Selections, res).MarshalGQL(w) + ec.marshalNProduct2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋv2ᚋpkgᚋtestingᚋfederationtestingᚋproductsᚋgraphᚋmodelᚐProduct(ctx, field.Selections, res).MarshalGQL(w) w.Write([]byte{'}'}) }) case <-ctx.Done(): @@ -1070,7 +1196,7 @@ func (ec *executionContext) _Subscription_updateProductPrice(ctx context.Context w.Write([]byte{'{'}) graphql.MarshalString(field.Alias).MarshalGQL(w) w.Write([]byte{':'}) - ec.marshalNProduct2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋpkgᚋtestingᚋfederationtestingᚋproductsᚋgraphᚋmodelᚐProduct(ctx, field.Selections, res).MarshalGQL(w) + ec.marshalNProduct2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋv2ᚋpkgᚋtestingᚋfederationtestingᚋproductsᚋgraphᚋmodelᚐProduct(ctx, field.Selections, res).MarshalGQL(w) w.Write([]byte{'}'}) }) case <-ctx.Done(): @@ -3004,6 +3130,42 @@ func (ec *executionContext) _Entity(ctx context.Context, sel ast.SelectionSet) g return out } +var mutationImplementors = []string{"Mutation"} + +func (ec *executionContext) _Mutation(ctx context.Context, sel ast.SelectionSet) graphql.Marshaler { + fields := graphql.CollectFields(ec.OperationContext, sel, mutationImplementors) + ctx = graphql.WithFieldContext(ctx, &graphql.FieldContext{ + Object: "Mutation", + }) + + out := graphql.NewFieldSet(fields) + var invalids uint32 + for i, field := range fields { + innerCtx := graphql.WithRootFieldContext(ctx, &graphql.RootFieldContext{ + Object: field.Name, + Field: field, + }) + + switch field.Name { + case "__typename": + out.Values[i] = graphql.MarshalString("Mutation") + case "setPrice": + + out.Values[i] = ec.OperationContext.RootResolverMiddleware(innerCtx, func(ctx context.Context) (res graphql.Marshaler) { + return ec._Mutation_setPrice(ctx, field) + }) + + default: + panic("unknown field " + strconv.Quote(field.Name)) + } + } + out.Dispatch() + if invalids > 0 { + return graphql.Null + } + return out +} + var productImplementors = []string{"Product", "_Entity"} func (ec *executionContext) _Product(ctx context.Context, sel ast.SelectionSet, obj *model.Product) graphql.Marshaler { @@ -3556,11 +3718,11 @@ func (ec *executionContext) marshalNInt2int(ctx context.Context, sel ast.Selecti return res } -func (ec *executionContext) marshalNProduct2githubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋpkgᚋtestingᚋfederationtestingᚋproductsᚋgraphᚋmodelᚐProduct(ctx context.Context, sel ast.SelectionSet, v model.Product) graphql.Marshaler { +func (ec *executionContext) marshalNProduct2githubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋv2ᚋpkgᚋtestingᚋfederationtestingᚋproductsᚋgraphᚋmodelᚐProduct(ctx context.Context, sel ast.SelectionSet, v model.Product) graphql.Marshaler { return ec._Product(ctx, sel, &v) } -func (ec *executionContext) marshalNProduct2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋpkgᚋtestingᚋfederationtestingᚋproductsᚋgraphᚋmodelᚐProduct(ctx context.Context, sel ast.SelectionSet, v *model.Product) graphql.Marshaler { +func (ec *executionContext) marshalNProduct2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋv2ᚋpkgᚋtestingᚋfederationtestingᚋproductsᚋgraphᚋmodelᚐProduct(ctx context.Context, sel ast.SelectionSet, v *model.Product) graphql.Marshaler { if v == nil { if !graphql.HasFieldError(ctx, graphql.GetFieldContext(ctx)) { ec.Errorf(ctx, "the requested element is null which the schema does not allow") @@ -3990,7 +4152,7 @@ func (ec *executionContext) marshalOInt2ᚖint(ctx context.Context, sel ast.Sele return res } -func (ec *executionContext) marshalOProduct2ᚕᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋpkgᚋtestingᚋfederationtestingᚋproductsᚋgraphᚋmodelᚐProduct(ctx context.Context, sel ast.SelectionSet, v []*model.Product) graphql.Marshaler { +func (ec *executionContext) marshalOProduct2ᚕᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋv2ᚋpkgᚋtestingᚋfederationtestingᚋproductsᚋgraphᚋmodelᚐProduct(ctx context.Context, sel ast.SelectionSet, v []*model.Product) graphql.Marshaler { if v == nil { return graphql.Null } @@ -4017,7 +4179,7 @@ func (ec *executionContext) marshalOProduct2ᚕᚖgithubᚗcomᚋwundergraphᚋg if !isLen1 { defer wg.Done() } - ret[i] = ec.marshalOProduct2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋpkgᚋtestingᚋfederationtestingᚋproductsᚋgraphᚋmodelᚐProduct(ctx, sel, v[i]) + ret[i] = ec.marshalOProduct2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋv2ᚋpkgᚋtestingᚋfederationtestingᚋproductsᚋgraphᚋmodelᚐProduct(ctx, sel, v[i]) } if isLen1 { f(i) @@ -4031,7 +4193,7 @@ func (ec *executionContext) marshalOProduct2ᚕᚖgithubᚗcomᚋwundergraphᚋg return ret } -func (ec *executionContext) marshalOProduct2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋpkgᚋtestingᚋfederationtestingᚋproductsᚋgraphᚋmodelᚐProduct(ctx context.Context, sel ast.SelectionSet, v *model.Product) graphql.Marshaler { +func (ec *executionContext) marshalOProduct2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋv2ᚋpkgᚋtestingᚋfederationtestingᚋproductsᚋgraphᚋmodelᚐProduct(ctx context.Context, sel ast.SelectionSet, v *model.Product) graphql.Marshaler { if v == nil { return graphql.Null } diff --git a/v2/pkg/testing/federationtesting/products/graph/schema.resolvers.go b/v2/pkg/testing/federationtesting/products/graph/schema.resolvers.go index b85da7eb6..c714c7972 100644 --- a/v2/pkg/testing/federationtesting/products/graph/schema.resolvers.go +++ b/v2/pkg/testing/federationtesting/products/graph/schema.resolvers.go @@ -14,6 +14,11 @@ import ( "github.com/wundergraph/graphql-go-tools/v2/pkg/testing/federationtesting/products/graph/model" ) +// SetPrice is the resolver for the setPrice field. +func (r *mutationResolver) SetPrice(ctx context.Context, upc string, price int) (*model.Product, error) { + panic(fmt.Errorf("not implemented: SetPrice - setPrice")) +} + // TopProducts is the resolver for the topProducts field. func (r *queryResolver) TopProducts(ctx context.Context, first *int) ([]*model.Product, error) { return hats, nil @@ -80,11 +85,15 @@ func (r *subscriptionResolver) UpdateProductPrice(ctx context.Context, upc strin return updatedPrice, nil } +// Mutation returns generated.MutationResolver implementation. +func (r *Resolver) Mutation() generated.MutationResolver { return &mutationResolver{r} } + // Query returns generated.QueryResolver implementation. func (r *Resolver) Query() generated.QueryResolver { return &queryResolver{r} } // Subscription returns generated.SubscriptionResolver implementation. func (r *Resolver) Subscription() generated.SubscriptionResolver { return &subscriptionResolver{r} } +type mutationResolver struct{ *Resolver } type queryResolver struct{ *Resolver } type subscriptionResolver struct{ *Resolver } diff --git a/v2/pkg/testing/federationtesting/reviews/graph/generated/generated.go b/v2/pkg/testing/federationtesting/reviews/graph/generated/generated.go index 8a8046d92..318785cba 100644 --- a/v2/pkg/testing/federationtesting/reviews/graph/generated/generated.go +++ b/v2/pkg/testing/federationtesting/reviews/graph/generated/generated.go @@ -712,7 +712,7 @@ func (ec *executionContext) _Entity_findProductByUpc(ctx context.Context, field } res := resTmp.(*model.Product) fc.Result = res - return ec.marshalNProduct2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋpkgᚋtestingᚋfederationtestingᚋreviewsᚋgraphᚋmodelᚐProduct(ctx, field.Selections, res) + return ec.marshalNProduct2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋv2ᚋpkgᚋtestingᚋfederationtestingᚋreviewsᚋgraphᚋmodelᚐProduct(ctx, field.Selections, res) } func (ec *executionContext) fieldContext_Entity_findProductByUpc(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { @@ -773,7 +773,7 @@ func (ec *executionContext) _Entity_findUserByID(ctx context.Context, field grap } res := resTmp.(*model.User) fc.Result = res - return ec.marshalNUser2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋpkgᚋtestingᚋfederationtestingᚋreviewsᚋgraphᚋmodelᚐUser(ctx, field.Selections, res) + return ec.marshalNUser2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋv2ᚋpkgᚋtestingᚋfederationtestingᚋreviewsᚋgraphᚋmodelᚐUser(ctx, field.Selections, res) } func (ec *executionContext) fieldContext_Entity_findUserByID(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { @@ -838,7 +838,7 @@ func (ec *executionContext) _Mutation_addReview(ctx context.Context, field graph } res := resTmp.(*model.Review) fc.Result = res - return ec.marshalNReview2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋpkgᚋtestingᚋfederationtestingᚋreviewsᚋgraphᚋmodelᚐReview(ctx, field.Selections, res) + return ec.marshalNReview2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋv2ᚋpkgᚋtestingᚋfederationtestingᚋreviewsᚋgraphᚋmodelᚐReview(ctx, field.Selections, res) } func (ec *executionContext) fieldContext_Mutation_addReview(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { @@ -944,7 +944,7 @@ func (ec *executionContext) _Product_reviews(ctx context.Context, field graphql. } res := resTmp.([]*model.Review) fc.Result = res - return ec.marshalOReview2ᚕᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋpkgᚋtestingᚋfederationtestingᚋreviewsᚋgraphᚋmodelᚐReview(ctx, field.Selections, res) + return ec.marshalOReview2ᚕᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋv2ᚋpkgᚋtestingᚋfederationtestingᚋreviewsᚋgraphᚋmodelᚐReview(ctx, field.Selections, res) } func (ec *executionContext) fieldContext_Product_reviews(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { @@ -995,7 +995,7 @@ func (ec *executionContext) _Query_me(ctx context.Context, field graphql.Collect } res := resTmp.(*model.User) fc.Result = res - return ec.marshalOUser2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋpkgᚋtestingᚋfederationtestingᚋreviewsᚋgraphᚋmodelᚐUser(ctx, field.Selections, res) + return ec.marshalOUser2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋv2ᚋpkgᚋtestingᚋfederationtestingᚋreviewsᚋgraphᚋmodelᚐUser(ctx, field.Selections, res) } func (ec *executionContext) fieldContext_Query_me(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { @@ -1046,7 +1046,7 @@ func (ec *executionContext) _Query_cat(ctx context.Context, field graphql.Collec } res := resTmp.(*model.Cat) fc.Result = res - return ec.marshalOCat2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋpkgᚋtestingᚋfederationtestingᚋreviewsᚋgraphᚋmodelᚐCat(ctx, field.Selections, res) + return ec.marshalOCat2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋv2ᚋpkgᚋtestingᚋfederationtestingᚋreviewsᚋgraphᚋmodelᚐCat(ctx, field.Selections, res) } func (ec *executionContext) fieldContext_Query_cat(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { @@ -1590,7 +1590,7 @@ func (ec *executionContext) _Review_author(ctx context.Context, field graphql.Co } res := resTmp.(*model.User) fc.Result = res - return ec.marshalNUser2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋpkgᚋtestingᚋfederationtestingᚋreviewsᚋgraphᚋmodelᚐUser(ctx, field.Selections, res) + return ec.marshalNUser2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋv2ᚋpkgᚋtestingᚋfederationtestingᚋreviewsᚋgraphᚋmodelᚐUser(ctx, field.Selections, res) } func (ec *executionContext) fieldContext_Review_author(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { @@ -1644,7 +1644,7 @@ func (ec *executionContext) _Review_product(ctx context.Context, field graphql.C } res := resTmp.(*model.Product) fc.Result = res - return ec.marshalNProduct2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋpkgᚋtestingᚋfederationtestingᚋreviewsᚋgraphᚋmodelᚐProduct(ctx, field.Selections, res) + return ec.marshalNProduct2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋv2ᚋpkgᚋtestingᚋfederationtestingᚋreviewsᚋgraphᚋmodelᚐProduct(ctx, field.Selections, res) } func (ec *executionContext) fieldContext_Review_product(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { @@ -1691,7 +1691,7 @@ func (ec *executionContext) _Review_attachments(ctx context.Context, field graph } res := resTmp.([]model.Attachment) fc.Result = res - return ec.marshalOAttachment2ᚕgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋpkgᚋtestingᚋfederationtestingᚋreviewsᚋgraphᚋmodelᚐAttachment(ctx, field.Selections, res) + return ec.marshalOAttachment2ᚕgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋv2ᚋpkgᚋtestingᚋfederationtestingᚋreviewsᚋgraphᚋmodelᚐAttachment(ctx, field.Selections, res) } func (ec *executionContext) fieldContext_Review_attachments(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { @@ -1820,7 +1820,7 @@ func (ec *executionContext) _User_reviews(ctx context.Context, field graphql.Col } res := resTmp.([]*model.Review) fc.Result = res - return ec.marshalOReview2ᚕᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋpkgᚋtestingᚋfederationtestingᚋreviewsᚋgraphᚋmodelᚐReview(ctx, field.Selections, res) + return ec.marshalOReview2ᚕᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋv2ᚋpkgᚋtestingᚋfederationtestingᚋreviewsᚋgraphᚋmodelᚐReview(ctx, field.Selections, res) } func (ec *executionContext) fieldContext_User_reviews(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { @@ -4851,11 +4851,11 @@ func (ec *executionContext) marshalNInt2int(ctx context.Context, sel ast.Selecti return res } -func (ec *executionContext) marshalNProduct2githubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋpkgᚋtestingᚋfederationtestingᚋreviewsᚋgraphᚋmodelᚐProduct(ctx context.Context, sel ast.SelectionSet, v model.Product) graphql.Marshaler { +func (ec *executionContext) marshalNProduct2githubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋv2ᚋpkgᚋtestingᚋfederationtestingᚋreviewsᚋgraphᚋmodelᚐProduct(ctx context.Context, sel ast.SelectionSet, v model.Product) graphql.Marshaler { return ec._Product(ctx, sel, &v) } -func (ec *executionContext) marshalNProduct2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋpkgᚋtestingᚋfederationtestingᚋreviewsᚋgraphᚋmodelᚐProduct(ctx context.Context, sel ast.SelectionSet, v *model.Product) graphql.Marshaler { +func (ec *executionContext) marshalNProduct2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋv2ᚋpkgᚋtestingᚋfederationtestingᚋreviewsᚋgraphᚋmodelᚐProduct(ctx context.Context, sel ast.SelectionSet, v *model.Product) graphql.Marshaler { if v == nil { if !graphql.HasFieldError(ctx, graphql.GetFieldContext(ctx)) { ec.Errorf(ctx, "the requested element is null which the schema does not allow") @@ -4865,11 +4865,11 @@ func (ec *executionContext) marshalNProduct2ᚖgithubᚗcomᚋwundergraphᚋgrap return ec._Product(ctx, sel, v) } -func (ec *executionContext) marshalNReview2githubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋpkgᚋtestingᚋfederationtestingᚋreviewsᚋgraphᚋmodelᚐReview(ctx context.Context, sel ast.SelectionSet, v model.Review) graphql.Marshaler { +func (ec *executionContext) marshalNReview2githubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋv2ᚋpkgᚋtestingᚋfederationtestingᚋreviewsᚋgraphᚋmodelᚐReview(ctx context.Context, sel ast.SelectionSet, v model.Review) graphql.Marshaler { return ec._Review(ctx, sel, &v) } -func (ec *executionContext) marshalNReview2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋpkgᚋtestingᚋfederationtestingᚋreviewsᚋgraphᚋmodelᚐReview(ctx context.Context, sel ast.SelectionSet, v *model.Review) graphql.Marshaler { +func (ec *executionContext) marshalNReview2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋv2ᚋpkgᚋtestingᚋfederationtestingᚋreviewsᚋgraphᚋmodelᚐReview(ctx context.Context, sel ast.SelectionSet, v *model.Review) graphql.Marshaler { if v == nil { if !graphql.HasFieldError(ctx, graphql.GetFieldContext(ctx)) { ec.Errorf(ctx, "the requested element is null which the schema does not allow") @@ -4894,11 +4894,11 @@ func (ec *executionContext) marshalNString2string(ctx context.Context, sel ast.S return res } -func (ec *executionContext) marshalNUser2githubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋpkgᚋtestingᚋfederationtestingᚋreviewsᚋgraphᚋmodelᚐUser(ctx context.Context, sel ast.SelectionSet, v model.User) graphql.Marshaler { +func (ec *executionContext) marshalNUser2githubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋv2ᚋpkgᚋtestingᚋfederationtestingᚋreviewsᚋgraphᚋmodelᚐUser(ctx context.Context, sel ast.SelectionSet, v model.User) graphql.Marshaler { return ec._User(ctx, sel, &v) } -func (ec *executionContext) marshalNUser2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋpkgᚋtestingᚋfederationtestingᚋreviewsᚋgraphᚋmodelᚐUser(ctx context.Context, sel ast.SelectionSet, v *model.User) graphql.Marshaler { +func (ec *executionContext) marshalNUser2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋv2ᚋpkgᚋtestingᚋfederationtestingᚋreviewsᚋgraphᚋmodelᚐUser(ctx context.Context, sel ast.SelectionSet, v *model.User) graphql.Marshaler { if v == nil { if !graphql.HasFieldError(ctx, graphql.GetFieldContext(ctx)) { ec.Errorf(ctx, "the requested element is null which the schema does not allow") @@ -5271,14 +5271,14 @@ func (ec *executionContext) marshalN__TypeKind2string(ctx context.Context, sel a return res } -func (ec *executionContext) marshalOAttachment2githubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋpkgᚋtestingᚋfederationtestingᚋreviewsᚋgraphᚋmodelᚐAttachment(ctx context.Context, sel ast.SelectionSet, v model.Attachment) graphql.Marshaler { +func (ec *executionContext) marshalOAttachment2githubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋv2ᚋpkgᚋtestingᚋfederationtestingᚋreviewsᚋgraphᚋmodelᚐAttachment(ctx context.Context, sel ast.SelectionSet, v model.Attachment) graphql.Marshaler { if v == nil { return graphql.Null } return ec._Attachment(ctx, sel, v) } -func (ec *executionContext) marshalOAttachment2ᚕgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋpkgᚋtestingᚋfederationtestingᚋreviewsᚋgraphᚋmodelᚐAttachment(ctx context.Context, sel ast.SelectionSet, v []model.Attachment) graphql.Marshaler { +func (ec *executionContext) marshalOAttachment2ᚕgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋv2ᚋpkgᚋtestingᚋfederationtestingᚋreviewsᚋgraphᚋmodelᚐAttachment(ctx context.Context, sel ast.SelectionSet, v []model.Attachment) graphql.Marshaler { if v == nil { return graphql.Null } @@ -5305,7 +5305,7 @@ func (ec *executionContext) marshalOAttachment2ᚕgithubᚗcomᚋwundergraphᚋg if !isLen1 { defer wg.Done() } - ret[i] = ec.marshalOAttachment2githubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋpkgᚋtestingᚋfederationtestingᚋreviewsᚋgraphᚋmodelᚐAttachment(ctx, sel, v[i]) + ret[i] = ec.marshalOAttachment2githubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋv2ᚋpkgᚋtestingᚋfederationtestingᚋreviewsᚋgraphᚋmodelᚐAttachment(ctx, sel, v[i]) } if isLen1 { f(i) @@ -5345,14 +5345,14 @@ func (ec *executionContext) marshalOBoolean2ᚖbool(ctx context.Context, sel ast return res } -func (ec *executionContext) marshalOCat2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋpkgᚋtestingᚋfederationtestingᚋreviewsᚋgraphᚋmodelᚐCat(ctx context.Context, sel ast.SelectionSet, v *model.Cat) graphql.Marshaler { +func (ec *executionContext) marshalOCat2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋv2ᚋpkgᚋtestingᚋfederationtestingᚋreviewsᚋgraphᚋmodelᚐCat(ctx context.Context, sel ast.SelectionSet, v *model.Cat) graphql.Marshaler { if v == nil { return graphql.Null } return ec._Cat(ctx, sel, v) } -func (ec *executionContext) marshalOReview2ᚕᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋpkgᚋtestingᚋfederationtestingᚋreviewsᚋgraphᚋmodelᚐReview(ctx context.Context, sel ast.SelectionSet, v []*model.Review) graphql.Marshaler { +func (ec *executionContext) marshalOReview2ᚕᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋv2ᚋpkgᚋtestingᚋfederationtestingᚋreviewsᚋgraphᚋmodelᚐReview(ctx context.Context, sel ast.SelectionSet, v []*model.Review) graphql.Marshaler { if v == nil { return graphql.Null } @@ -5379,7 +5379,7 @@ func (ec *executionContext) marshalOReview2ᚕᚖgithubᚗcomᚋwundergraphᚋgr if !isLen1 { defer wg.Done() } - ret[i] = ec.marshalOReview2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋpkgᚋtestingᚋfederationtestingᚋreviewsᚋgraphᚋmodelᚐReview(ctx, sel, v[i]) + ret[i] = ec.marshalOReview2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋv2ᚋpkgᚋtestingᚋfederationtestingᚋreviewsᚋgraphᚋmodelᚐReview(ctx, sel, v[i]) } if isLen1 { f(i) @@ -5393,7 +5393,7 @@ func (ec *executionContext) marshalOReview2ᚕᚖgithubᚗcomᚋwundergraphᚋgr return ret } -func (ec *executionContext) marshalOReview2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋpkgᚋtestingᚋfederationtestingᚋreviewsᚋgraphᚋmodelᚐReview(ctx context.Context, sel ast.SelectionSet, v *model.Review) graphql.Marshaler { +func (ec *executionContext) marshalOReview2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋv2ᚋpkgᚋtestingᚋfederationtestingᚋreviewsᚋgraphᚋmodelᚐReview(ctx context.Context, sel ast.SelectionSet, v *model.Review) graphql.Marshaler { if v == nil { return graphql.Null } @@ -5426,7 +5426,7 @@ func (ec *executionContext) marshalOString2ᚖstring(ctx context.Context, sel as return res } -func (ec *executionContext) marshalOUser2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋpkgᚋtestingᚋfederationtestingᚋreviewsᚋgraphᚋmodelᚐUser(ctx context.Context, sel ast.SelectionSet, v *model.User) graphql.Marshaler { +func (ec *executionContext) marshalOUser2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋv2ᚋpkgᚋtestingᚋfederationtestingᚋreviewsᚋgraphᚋmodelᚐUser(ctx context.Context, sel ast.SelectionSet, v *model.User) graphql.Marshaler { if v == nil { return graphql.Null } diff --git a/v2/pkg/testing/subscriptiontesting/generated.go b/v2/pkg/testing/subscriptiontesting/generated.go index 4f6260726..72b798e3e 100644 --- a/v2/pkg/testing/subscriptiontesting/generated.go +++ b/v2/pkg/testing/subscriptiontesting/generated.go @@ -520,7 +520,7 @@ func (ec *executionContext) _Chatroom_messages(ctx context.Context, field graphq } res := resTmp.([]Message) fc.Result = res - return ec.marshalNMessage2ᚕgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋpkgᚋtestingᚋsubscriptiontestingᚐMessageᚄ(ctx, field.Selections, res) + return ec.marshalNMessage2ᚕgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋv2ᚋpkgᚋtestingᚋsubscriptiontestingᚐMessageᚄ(ctx, field.Selections, res) } func (ec *executionContext) fieldContext_Chatroom_messages(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { @@ -750,7 +750,7 @@ func (ec *executionContext) _Mutation_post(ctx context.Context, field graphql.Co } res := resTmp.(*Message) fc.Result = res - return ec.marshalNMessage2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋpkgᚋtestingᚋsubscriptiontestingᚐMessage(ctx, field.Selections, res) + return ec.marshalNMessage2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋv2ᚋpkgᚋtestingᚋsubscriptiontestingᚐMessage(ctx, field.Selections, res) } func (ec *executionContext) fieldContext_Mutation_post(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { @@ -812,7 +812,7 @@ func (ec *executionContext) _Query_room(ctx context.Context, field graphql.Colle } res := resTmp.(*Chatroom) fc.Result = res - return ec.marshalOChatroom2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋpkgᚋtestingᚋsubscriptiontestingᚐChatroom(ctx, field.Selections, res) + return ec.marshalOChatroom2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋv2ᚋpkgᚋtestingᚋsubscriptiontestingᚐChatroom(ctx, field.Selections, res) } func (ec *executionContext) fieldContext_Query_room(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { @@ -1010,7 +1010,7 @@ func (ec *executionContext) _Subscription_messageAdded(ctx context.Context, fiel w.Write([]byte{'{'}) graphql.MarshalString(field.Alias).MarshalGQL(w) w.Write([]byte{':'}) - ec.marshalNMessage2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋpkgᚋtestingᚋsubscriptiontestingᚐMessage(ctx, field.Selections, res).MarshalGQL(w) + ec.marshalNMessage2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋv2ᚋpkgᚋtestingᚋsubscriptiontestingᚐMessage(ctx, field.Selections, res).MarshalGQL(w) w.Write([]byte{'}'}) }) case <-ctx.Done(): @@ -3387,11 +3387,11 @@ func (ec *executionContext) marshalNID2string(ctx context.Context, sel ast.Selec return res } -func (ec *executionContext) marshalNMessage2githubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋpkgᚋtestingᚋsubscriptiontestingᚐMessage(ctx context.Context, sel ast.SelectionSet, v Message) graphql.Marshaler { +func (ec *executionContext) marshalNMessage2githubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋv2ᚋpkgᚋtestingᚋsubscriptiontestingᚐMessage(ctx context.Context, sel ast.SelectionSet, v Message) graphql.Marshaler { return ec._Message(ctx, sel, &v) } -func (ec *executionContext) marshalNMessage2ᚕgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋpkgᚋtestingᚋsubscriptiontestingᚐMessageᚄ(ctx context.Context, sel ast.SelectionSet, v []Message) graphql.Marshaler { +func (ec *executionContext) marshalNMessage2ᚕgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋv2ᚋpkgᚋtestingᚋsubscriptiontestingᚐMessageᚄ(ctx context.Context, sel ast.SelectionSet, v []Message) graphql.Marshaler { ret := make(graphql.Array, len(v)) var wg sync.WaitGroup isLen1 := len(v) == 1 @@ -3415,7 +3415,7 @@ func (ec *executionContext) marshalNMessage2ᚕgithubᚗcomᚋwundergraphᚋgrap if !isLen1 { defer wg.Done() } - ret[i] = ec.marshalNMessage2githubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋpkgᚋtestingᚋsubscriptiontestingᚐMessage(ctx, sel, v[i]) + ret[i] = ec.marshalNMessage2githubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋv2ᚋpkgᚋtestingᚋsubscriptiontestingᚐMessage(ctx, sel, v[i]) } if isLen1 { f(i) @@ -3435,7 +3435,7 @@ func (ec *executionContext) marshalNMessage2ᚕgithubᚗcomᚋwundergraphᚋgrap return ret } -func (ec *executionContext) marshalNMessage2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋpkgᚋtestingᚋsubscriptiontestingᚐMessage(ctx context.Context, sel ast.SelectionSet, v *Message) graphql.Marshaler { +func (ec *executionContext) marshalNMessage2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋv2ᚋpkgᚋtestingᚋsubscriptiontestingᚐMessage(ctx context.Context, sel ast.SelectionSet, v *Message) graphql.Marshaler { if v == nil { if !graphql.HasFieldError(ctx, graphql.GetFieldContext(ctx)) { ec.Errorf(ctx, "the requested element is null which the schema does not allow") @@ -3754,7 +3754,7 @@ func (ec *executionContext) marshalOBoolean2ᚖbool(ctx context.Context, sel ast return res } -func (ec *executionContext) marshalOChatroom2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋpkgᚋtestingᚋsubscriptiontestingᚐChatroom(ctx context.Context, sel ast.SelectionSet, v *Chatroom) graphql.Marshaler { +func (ec *executionContext) marshalOChatroom2ᚖgithubᚗcomᚋwundergraphᚋgraphqlᚑgoᚑtoolsᚋv2ᚋpkgᚋtestingᚋsubscriptiontestingᚐChatroom(ctx context.Context, sel ast.SelectionSet, v *Chatroom) graphql.Marshaler { if v == nil { return graphql.Null } From e38d33502c1ccefc94a2fe117b14e8c988680fae Mon Sep 17 00:00:00 2001 From: fiam Date: Thu, 5 Oct 2023 11:35:29 +0100 Subject: [PATCH 4/5] fix: serialize graphql messages appropriately, omiting variables when empty Fix tests too --- go.mod | 3 +-- go.sum | 17 +++++++++-------- pkg/graphql/request.go | 2 +- .../protocol_graphql_transport_ws_test.go | 11 +++++++---- v2/pkg/graphql/request.go | 2 +- .../protocol_graphql_transport_ws_test.go | 10 ++++++---- 6 files changed, 25 insertions(+), 20 deletions(-) diff --git a/go.mod b/go.mod index 7825ed65e..a622a5b6a 100644 --- a/go.mod +++ b/go.mod @@ -12,7 +12,7 @@ require ( github.com/evanphx/json-patch/v5 v5.1.0 github.com/go-test/deep v1.0.8 github.com/gobwas/ws v1.0.4 - github.com/golang/mock v1.4.1 + github.com/golang/mock v1.6.0 github.com/google/go-cmp v0.5.8 github.com/gorilla/websocket v1.5.0 github.com/hashicorp/golang-lru v0.5.4 @@ -25,7 +25,6 @@ require ( github.com/nats-io/nats.go v1.19.1 github.com/r3labs/sse/v2 v2.8.1 github.com/santhosh-tekuri/jsonschema/v5 v5.3.0 - github.com/sashabaranov/go-openai v1.14.1 github.com/sebdah/goldie/v2 v2.5.3 github.com/spf13/cobra v0.0.5 github.com/spf13/viper v1.3.2 diff --git a/go.sum b/go.sum index 0fa6da673..7d4c50616 100644 --- a/go.sum +++ b/go.sum @@ -64,8 +64,8 @@ github.com/gobwas/pool v0.2.0/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6Wezm github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM= github.com/gobwas/ws v1.0.4 h1:5eXU1CZhpQdq5kXbKb+sECH5Ia5KiO6CYzIzdlVx6Bs= github.com/gobwas/ws v1.0.4/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM= -github.com/golang/mock v1.4.1 h1:ocYkMQY5RrXTYgXl7ICpV0IXwlEQGwKIsery4gyXa1U= -github.com/golang/mock v1.4.1/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= +github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= +github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= github.com/golang/protobuf v1.3.5/go.mod h1:6O5/vntMXwX2lRkT1hjjk0nAC1IDOTvTlVgjlRvqsdk= github.com/golang/protobuf v1.5.0 h1:LUVKkCeviFUMKqHa4tXIIij/lbhnMbP7Fn5wKdKkRh4= @@ -169,8 +169,6 @@ github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/santhosh-tekuri/jsonschema/v5 v5.3.0 h1:uIkTLo0AGRc8l7h5l9r+GcYi9qfVPt6lD4/bhmzfiKo= github.com/santhosh-tekuri/jsonschema/v5 v5.3.0/go.mod h1:FKdcjfQW6rpZSnxxUvEA5H/cDPdvJ/SZJQLWWXWGrZ0= -github.com/sashabaranov/go-openai v1.14.1 h1:jqfkdj8XHnBF84oi2aNtT8Ktp3EJ0MfuVjvcMkfI0LA= -github.com/sashabaranov/go-openai v1.14.1/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= github.com/sebdah/goldie/v2 v2.5.3 h1:9ES/mNN+HNUbNWpVAlrzuZ7jE+Nrczbj8uFRjM7624Y= github.com/sebdah/goldie/v2 v2.5.3/go.mod h1:oZ9fp0+se1eapSRjfYbsV/0Hqhbuu3bJVvKI/NNtssI= github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= @@ -218,6 +216,7 @@ github.com/vektah/gqlparser/v2 v2.5.1 h1:ZGu+bquAY23jsxDRcYpWjttRZrUz07LbiY77gUO github.com/vektah/gqlparser/v2 v2.5.1/go.mod h1:mPgqFBu/woKTVYWyNk8cO3kh4S/f4aRFZrvOnp3hmCs= github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q= github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673/go.mod h1:N3UwUGtsrSj3ccvlPHLoLsHnpR27oXr4ZE984MbSER8= +github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/goldmark v1.4.1/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= @@ -237,6 +236,7 @@ golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnf golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191227163750-53104e6ec876/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210314154223-e6e6c4f2bb5b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= @@ -248,6 +248,7 @@ golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHl golang.org/x/lint v0.0.0-20191125180803-fdd1cda4f05f h1:J5lckAjkw6qYlOZNj90mLYNTEKDvWeuc1yieZ8qUzUE= golang.org/x/lint v0.0.0-20191125180803-fdd1cda4f05f/go.mod h1:5qLYkcX4OjUUV8bRuDixDT3tpyyb+LUpUlRWLxfhWrs= golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= +golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3/go.mod h1:3p9vT2HGsQu2K1YbXdKPJLVgG5VJdoTa1poYQBtP1AY= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= @@ -255,6 +256,7 @@ golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20191116160921-f9c825593386/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.11.0 h1:Gi2tvZIJyBtO9SDr1q9h5hEQCp/4L2RQ+ar0qjx2oNU= @@ -270,7 +272,9 @@ golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211019181941-9d821ace8654/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -280,7 +284,6 @@ golang.org/x/sys v0.9.0 h1:KS/R3tvhPqvJvwcKfnBHJwwthS11LRhmM5D59eEXa0s= golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= @@ -293,13 +296,13 @@ golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxb golang.org/x/time v0.0.0-20211116232009-f0f3c7e86c11 h1:GZokNIeuVkl3aZHJchRrr13WCsols02MLUcz1U9is6M= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= -golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191108193012-7d206e10da11/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191125144606-a911d9008d1f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.10/go.mod h1:Uh6Zz+xoGYZom868N8YTex3t7RhtHDBrE8Gzo9bV56E= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0 h1:BOw41kyTf3PuCW1pVQf8+Cyg8pMlkYB1oo9iJ6D/lKM= @@ -331,5 +334,3 @@ gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= nhooyr.io/websocket v1.8.7 h1:usjR2uOr/zjjkVMy0lW+PPohFok7PCow5sDjLgX4P4g= nhooyr.io/websocket v1.8.7/go.mod h1:B70DZP8IakI65RVQ51MsWP/8jndNma26DVA/nFSCgW0= -rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0= -rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA= diff --git a/pkg/graphql/request.go b/pkg/graphql/request.go index 9a5d0e2c0..4430235e4 100644 --- a/pkg/graphql/request.go +++ b/pkg/graphql/request.go @@ -35,7 +35,7 @@ var ( type Request struct { OperationName string `json:"operationName"` - Variables json.RawMessage `json:"variables"` + Variables json.RawMessage `json:"variables,omitempty"` Query string `json:"query"` document ast.Document diff --git a/pkg/subscription/websocket/protocol_graphql_transport_ws_test.go b/pkg/subscription/websocket/protocol_graphql_transport_ws_test.go index 56440880d..5fe382f8d 100644 --- a/pkg/subscription/websocket/protocol_graphql_transport_ws_test.go +++ b/pkg/subscription/websocket/protocol_graphql_transport_ws_test.go @@ -475,13 +475,17 @@ func TestProtocolGraphQLTransportWSHandler_Handle(t *testing.T) { operation := []byte(`{"operationName":"Hello","query":"query Hello { hello }"}`) ctrl := gomock.NewController(t) + defer ctrl.Finish() mockEngine := NewMockEngine(ctrl) - mockEngine.EXPECT().StartOperation(gomock.Eq(ctx), gomock.Eq("1"), gomock.Eq(operation), gomock.Eq(&protocol.eventHandler)) + mockEngine.EXPECT().StartOperation(gomock.Eq(ctx), gomock.Eq("2"), gomock.Eq(operation), gomock.Eq(&protocol.eventHandler)) assert.Eventually(t, func() bool { - inputMessage := []byte(`{"id":"1","type":"subscribe","payload":{"operationName":"Hello","query":"query Hello { hello }"}}`) - err := protocol.Handle(ctx, mockEngine, inputMessage) + initMessage := []byte(`{"id":"1","type":"connection_init"}`) + err := protocol.Handle(ctx, mockEngine, initMessage) assert.NoError(t, err) + subscribeMessage := []byte(`{"id":"2","type":"subscribe","payload":` + string(operation) + `}`) + err2 := protocol.Handle(ctx, mockEngine, subscribeMessage) + assert.NoError(t, err2) return true }, 1*time.Second, 2*time.Millisecond) }) @@ -533,7 +537,6 @@ func TestProtocolGraphQLTransportWSHandler_Handle(t *testing.T) { ctrl := gomock.NewController(t) mockEngine := NewMockEngine(ctrl) - mockEngine.EXPECT().StopSubscription(gomock.Eq("1"), gomock.Eq(&protocol.eventHandler)) assert.Eventually(t, func() bool { inputMessage := []byte(`{"type":"connection_init","payload":{something}}`) diff --git a/v2/pkg/graphql/request.go b/v2/pkg/graphql/request.go index cf9c2cbb3..476f1b7df 100644 --- a/v2/pkg/graphql/request.go +++ b/v2/pkg/graphql/request.go @@ -34,7 +34,7 @@ var ( type Request struct { OperationName string `json:"operationName"` - Variables json.RawMessage `json:"variables"` + Variables json.RawMessage `json:"variables,omitempty"` Query string `json:"query"` document ast.Document diff --git a/v2/pkg/subscription/websocket/protocol_graphql_transport_ws_test.go b/v2/pkg/subscription/websocket/protocol_graphql_transport_ws_test.go index e01e1c513..caf32095e 100644 --- a/v2/pkg/subscription/websocket/protocol_graphql_transport_ws_test.go +++ b/v2/pkg/subscription/websocket/protocol_graphql_transport_ws_test.go @@ -476,12 +476,15 @@ func TestProtocolGraphQLTransportWSHandler_Handle(t *testing.T) { operation := []byte(`{"operationName":"Hello","query":"query Hello { hello }"}`) ctrl := gomock.NewController(t) mockEngine := NewMockEngine(ctrl) - mockEngine.EXPECT().StartOperation(gomock.Eq(ctx), gomock.Eq("1"), gomock.Eq(operation), gomock.Eq(&protocol.eventHandler)) + mockEngine.EXPECT().StartOperation(gomock.Eq(ctx), gomock.Eq("2"), gomock.Eq(operation), gomock.Eq(&protocol.eventHandler)) assert.Eventually(t, func() bool { - inputMessage := []byte(`{"id":"1","type":"subscribe","payload":{"operationName":"Hello","query":"query Hello { hello }"}}`) - err := protocol.Handle(ctx, mockEngine, inputMessage) + initMessage := []byte(`{"id":"1","type":"connection_init"}`) + err := protocol.Handle(ctx, mockEngine, initMessage) assert.NoError(t, err) + subscribeMessage := []byte(`{"id":"2","type":"subscribe","payload":` + string(operation) + `}`) + err2 := protocol.Handle(ctx, mockEngine, subscribeMessage) + assert.NoError(t, err2) return true }, 1*time.Second, 2*time.Millisecond) }) @@ -533,7 +536,6 @@ func TestProtocolGraphQLTransportWSHandler_Handle(t *testing.T) { ctrl := gomock.NewController(t) mockEngine := NewMockEngine(ctrl) - mockEngine.EXPECT().StopSubscription(gomock.Eq("1"), gomock.Eq(&protocol.eventHandler)) assert.Eventually(t, func() bool { inputMessage := []byte(`{"type":"connection_init","payload":{something}}`) From eaa71ec8f3e31952f13c4edf7c5276dd84a35f09 Mon Sep 17 00:00:00 2001 From: fiam Date: Thu, 5 Oct 2023 11:49:33 +0100 Subject: [PATCH 5/5] chore: disable test on Windows --- v2/pkg/subscription/websocket/handler_test.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/v2/pkg/subscription/websocket/handler_test.go b/v2/pkg/subscription/websocket/handler_test.go index 0f068aebe..d366848e0 100644 --- a/v2/pkg/subscription/websocket/handler_test.go +++ b/v2/pkg/subscription/websocket/handler_test.go @@ -7,6 +7,7 @@ import ( "net" "net/http" "net/http/httptest" + "runtime" "testing" "time" @@ -24,6 +25,9 @@ import ( func TestHandleWithOptions(t *testing.T) { t.Run("should handle protocol graphql-ws", func(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("this test fails on Windows due to different timings than unix, consider fixing it at some point") + } chatServer := httptest.NewServer(subscriptiontesting.ChatGraphQLEndpointHandler()) defer chatServer.Close()