Skip to content

Commit

Permalink
fix flaky test in v2
Browse files Browse the repository at this point in the history
  • Loading branch information
pvormste committed Oct 30, 2023
1 parent 6d5b337 commit 61b327d
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 34 deletions.
2 changes: 1 addition & 1 deletion v2/pkg/subscription/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ func (u *UniversalProtocolHandler) Handle(ctx context.Context) {
TimeOutAction: func() {
cancel() // stop the handler if timer runs out
},
TimeOutDuration: u.readErrorTimeOut,
Timer: time.NewTimer(u.readErrorTimeOut),
}
go TimeOutChecker(params)
u.isReadTimeOutTimerRunning = true
Expand Down
21 changes: 13 additions & 8 deletions v2/pkg/subscription/time_out.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,25 +9,30 @@ import (

// TimeOutParams is a struct to configure a TimeOutChecker.
type TimeOutParams struct {
Name string
Logger abstractlogger.Logger
TimeOutContext context.Context
TimeOutAction func()
TimeOutDuration time.Duration
Name string
Logger abstractlogger.Logger
TimeOutContext context.Context
TimeOutAction func()
Timer *time.Timer
}

// TimeOutChecker is a function that can be used in a go routine to perform a time-out action
// after a specific duration or prevent the time-out action by canceling the time-out context before.
// Use TimeOutParams for configuration.
func TimeOutChecker(params TimeOutParams) {
timer := time.NewTimer(params.TimeOutDuration)
defer timer.Stop()
if params.Timer == nil {
params.Logger.Error("timer is nil",
abstractlogger.String("name", params.Name),
)
return
}
defer params.Timer.Stop()

for {
select {
case <-params.TimeOutContext.Done():
return
case <-timer.C:
case <-params.Timer.C:
params.Logger.Error("time out happened",
abstractlogger.String("name", params.Name),
)
Expand Down
20 changes: 10 additions & 10 deletions v2/pkg/subscription/time_out_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ func TestTimeOutChecker(t *testing.T) {

timeOutCtx, timeOutCancel := context.WithCancel(context.Background())
params := TimeOutParams{
Name: "",
Logger: abstractlogger.Noop{},
TimeOutContext: timeOutCtx,
TimeOutAction: timeOutAction,
TimeOutDuration: 100 * time.Millisecond,
Name: "",
Logger: abstractlogger.Noop{},
TimeOutContext: timeOutCtx,
TimeOutAction: timeOutAction,
Timer: time.NewTimer(100 * time.Millisecond),
}
go TimeOutChecker(params)
time.Sleep(5 * time.Millisecond)
Expand All @@ -46,11 +46,11 @@ func TestTimeOutChecker(t *testing.T) {
defer timeOutCancel()

params := TimeOutParams{
Name: "",
Logger: abstractlogger.Noop{},
TimeOutContext: timeOutCtx,
TimeOutAction: timeOutAction,
TimeOutDuration: 10 * time.Millisecond,
Name: "",
Logger: abstractlogger.Noop{},
TimeOutContext: timeOutCtx,
TimeOutAction: timeOutAction,
Timer: time.NewTimer(10 * time.Millisecond),
}
go TimeOutChecker(params)
wg.Wait()
Expand Down
15 changes: 10 additions & 5 deletions v2/pkg/subscription/websocket/protocol_graphql_transport_ws.go
Original file line number Diff line number Diff line change
Expand Up @@ -335,15 +335,17 @@ func NewProtocolGraphQLTransportWSHandlerWithOptions(client subscription.Transpo
}

// Pass event functions
protocolHandler.eventHandler.OnConnectionOpened = protocolHandler.startConnectionInitTimer
protocolHandler.eventHandler.OnConnectionOpened = func() {
protocolHandler.startConnectionInitTimer(time.NewTimer(protocolHandler.connectionInitTimeOutDuration))
}

return protocolHandler, nil
}

// Handle will handle the actual graphql-transport-ws protocol messages. It's an implementation of subscription.Protocol.
func (p *ProtocolGraphQLTransportWSHandler) Handle(ctx context.Context, engine subscription.Engine, data []byte) error {
if !p.connectionAcknowledged && !p.connectionInitTimerStarted {
p.startConnectionInitTimer()
p.startConnectionInitTimer(time.NewTimer(p.connectionInitTimeOutDuration))
}

message, err := p.reader.Read(data)
Expand Down Expand Up @@ -393,11 +395,12 @@ func (p *ProtocolGraphQLTransportWSHandler) EventHandler() subscription.EventHan
return &p.eventHandler
}

func (p *ProtocolGraphQLTransportWSHandler) startConnectionInitTimer() {
func (p *ProtocolGraphQLTransportWSHandler) startConnectionInitTimer(timer *time.Timer) context.Context {
if p.connectionInitTimerStarted {
return
return context.Background()
}

timeOutActionContext, timeOutActionContextCancel := context.WithCancel(context.Background())
timeOutContext, timeOutContextCancel := context.WithCancel(context.Background())
p.connectionInitTimeOutCancel = timeOutContextCancel
p.connectionInitTimerStarted = true
Expand All @@ -409,10 +412,12 @@ func (p *ProtocolGraphQLTransportWSHandler) startConnectionInitTimer() {
p.closeConnectionWithReason(
NewCloseReason(4408, "Connection initialisation timeout"),
)
timeOutActionContextCancel()
},
TimeOutDuration: p.connectionInitTimeOutDuration,
Timer: timer,
}
go subscription.TimeOutChecker(timeOutParams)
return timeOutActionContext
}

func (p *ProtocolGraphQLTransportWSHandler) stopConnectionInitTimer() bool {
Expand Down
28 changes: 18 additions & 10 deletions v2/pkg/subscription/websocket/protocol_graphql_transport_ws_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -368,25 +368,31 @@ func TestProtocolGraphQLTransportWSHandler_Handle(t *testing.T) {

t.Run("for connection_init", func(t *testing.T) {
t.Run("should time out if no connection_init message is sent", func(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("this test fails on Windows due to different timings than unix, consider fixing it at some point")
}
testClient := NewTestClient(false)
protocol := NewTestProtocolGraphQLTransportWSHandler(testClient)
protocol.connectionInitTimeOutDuration = 2 * time.Millisecond
protocol.eventHandler.OnConnectionOpened = protocol.startConnectionInitTimer

var timeOutActionContext context.Context
protocol.connectionInitTimeOutDuration = 5 * time.Millisecond
protocol.eventHandler.OnConnectionOpened = func() {
timeOutActionContext = protocol.startConnectionInitTimer(time.NewTimer(protocol.connectionInitTimeOutDuration))
}

protocol.eventHandler.Emit(subscription.EventTypeOnConnectionOpened, "", nil, nil)
time.Sleep(10 * time.Millisecond)
assert.True(t, protocol.connectionInitTimerStarted)
assert.False(t, protocol.eventHandler.Writer.Client.IsConnected())
assert.Eventuallyf(t, func() bool {
<-timeOutActionContext.Done()
assert.True(t, protocol.connectionInitTimerStarted)
assert.False(t, protocol.eventHandler.Writer.Client.IsConnected())
return true
}, 1*time.Second, 2*time.Millisecond, "connection_init timer did not time out")
})

t.Run("should close connection after multiple connection_init messages", func(t *testing.T) {
testClient := NewTestClient(false)
protocol := NewTestProtocolGraphQLTransportWSHandler(testClient)
protocol.connectionInitTimeOutDuration = 50 * time.Millisecond
protocol.eventHandler.OnConnectionOpened = protocol.startConnectionInitTimer
protocol.eventHandler.OnConnectionOpened = func() {
protocol.startConnectionInitTimer(time.NewTimer(protocol.connectionInitTimeOutDuration))
}

ctrl := gomock.NewController(t)
mockEngine := NewMockEngine(ctrl)
Expand Down Expand Up @@ -418,7 +424,9 @@ func TestProtocolGraphQLTransportWSHandler_Handle(t *testing.T) {
protocol := NewTestProtocolGraphQLTransportWSHandler(testClient)
protocol.heartbeatInterval = 4 * time.Millisecond
protocol.connectionInitTimeOutDuration = 25 * time.Millisecond
protocol.eventHandler.OnConnectionOpened = protocol.startConnectionInitTimer
protocol.eventHandler.OnConnectionOpened = func() {
protocol.startConnectionInitTimer(time.NewTimer(protocol.connectionInitTimeOutDuration))
}

ctrl := gomock.NewController(t)
mockEngine := NewMockEngine(ctrl)
Expand Down

0 comments on commit 61b327d

Please sign in to comment.