From 470cae5e91ddbf5537a56af848eb23790325c98e Mon Sep 17 00:00:00 2001 From: Alexey Kiselev Date: Thu, 14 Nov 2024 12:55:30 +0400 Subject: [PATCH 01/30] TaskGroup goroutine manager added in pkg/execution package. Tests on TaskGroup added. --- go.mod | 1 + pkg/execution/taskgroup.go | 121 ++++++++++++++++++++++ pkg/execution/taskgroup_test.go | 173 ++++++++++++++++++++++++++++++++ 3 files changed, 295 insertions(+) create mode 100644 pkg/execution/taskgroup.go create mode 100644 pkg/execution/taskgroup_test.go diff --git a/go.mod b/go.mod index 93eefdc5f..b946f41e2 100644 --- a/go.mod +++ b/go.mod @@ -42,6 +42,7 @@ require ( github.com/valyala/bytebufferpool v1.0.0 github.com/xenolf/lego v2.7.2+incompatible go.uber.org/atomic v1.11.0 + go.uber.org/goleak v1.3.0 go.uber.org/zap v1.27.0 golang.org/x/crypto v0.29.0 golang.org/x/exp v0.0.0-20240904232852-e7e105dedf7e diff --git a/pkg/execution/taskgroup.go b/pkg/execution/taskgroup.go new file mode 100644 index 000000000..387d4e007 --- /dev/null +++ b/pkg/execution/taskgroup.go @@ -0,0 +1,121 @@ +package execution + +import ( + "sync" + "sync/atomic" +) + +// A TaskGroup manages a collection of cooperating goroutines. Add new tasks to the group with the Run method. +// Call the Wait method to wait for the tasks to complete. +// A zero value is ready for use, but must not be copied after its first use. +// +// The group collects any errors returned by the tasks in the group. +// The first non-nil error reported by any execution and not filtered is returned from the Wait method. +type TaskGroup struct { + wg sync.WaitGroup // Counter for active goroutines. + + // active is nonzero when the group is "active", meaning there has been at least one call to Run since the group + // was created or the last Wait. + // + // Together active and errLock work as a kind of resettable sync.Once. The fast path reads active and only + // acquires errLock if it discovers setup is needed. + active atomic.Uint32 + + errLock sync.Mutex // Guards the fields below. + err error // First captured error returned from Wait. + onError errorFunc // Called each time a task returns non-nil error. +} + +// NewTaskGroup constructs a new empty group with the specified error handler. +// See [TaskGroup.OnError] for a description of how errors are filtered. If handler is nil, no filtering is performed. +// Main properties of the TaskGroup are: +// - Cancel propagation. +// - Error propagation. +// - Waiting for all tasks to finish. +func NewTaskGroup(handler func(error) error) *TaskGroup { + return new(TaskGroup).OnError(handler) +} + +// OnError sets the error handler for TaskGroup. If handler is nil, +// the error handler is removed and errors are no longer filtered. Otherwise, each non-nil error reported by an +// execution running in g is passed to handler. +// +// Then handler is called with each reported error, and its result replaces the reported value. This permits handler to +// suppress or replace the error value selectively. +// +// Calls to handler are synchronized so that it is safe for handler to manipulate local data structures without +// additional locking. It is safe to call OnError while tasks are active in TaskGroup. +func (g *TaskGroup) OnError(handler func(error) error) *TaskGroup { + g.errLock.Lock() + defer g.errLock.Unlock() + g.onError = handler + return g +} + +// Run starts an [execute] function in a new goroutine in [TaskGroup]. The execution is not interrupted by TaskGroup, +// so the [execute] function should include the interruption logic. +func (g *TaskGroup) Run(execute func() error) { + g.wg.Add(1) + if g.active.Load() == 0 { + g.activate() + } + go func() { + defer g.wg.Done() + if err := execute(); err != nil { + g.handleError(err) + } + }() +} + +// Wait blocks until all the goroutines currently active in the TaskGroup have returned, and all reported errors have +// been delivered to the handler. It returns the first non-nil error reported by any of the goroutines in the group and +// not filtered by an OnError handler. +// +// As with sync.WaitGroup, new tasks can be added to TaskGroup during a Wait call only if the TaskGroup contains at +// least one active execution when Wait is called and continuously thereafter until the last concurrent call to +// Run returns. +// +// Wait may be called from at most one goroutine at a time. After Wait has returned, the group is ready for reuse. +func (g *TaskGroup) Wait() error { + g.wg.Wait() + g.errLock.Lock() + defer g.errLock.Unlock() + + // If the group is still active, deactivate it now. + if g.active.Load() != 0 { + g.active.Store(0) + } + return g.err +} + +// activate resets the state of the group and marks it as "active". This is triggered by adding a goroutine to +// an empty group. +func (g *TaskGroup) activate() { + g.errLock.Lock() + defer g.errLock.Unlock() + if g.active.Load() == 0 { + g.err = nil + g.active.Store(1) + } +} + +// handleError synchronizes access to the error handler and captures the first non-nil error. +func (g *TaskGroup) handleError(err error) { + g.errLock.Lock() + defer g.errLock.Unlock() + e := g.onError.filter(err) + if e != nil && g.err == nil { + g.err = e // Capture the first unfiltered error. + } +} + +// An errorFunc is called by a group each time an execution reports an error. Its return value replaces the reported +// error, so the errorFunc can filter or suppress errors by modifying or discarding the input error. +type errorFunc func(error) error + +func (ef errorFunc) filter(err error) error { + if ef == nil { + return err + } + return ef(err) +} diff --git a/pkg/execution/taskgroup_test.go b/pkg/execution/taskgroup_test.go new file mode 100644 index 000000000..bacedb3fc --- /dev/null +++ b/pkg/execution/taskgroup_test.go @@ -0,0 +1,173 @@ +package execution_test + +import ( + "context" + "errors" + "math/rand/v2" + "runtime" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/goleak" + + "github.com/wavesplatform/gowaves/pkg/execution" +) + +func TestBasic(t *testing.T) { + defer goleak.VerifyNone(t) + + // Verify that the group works at all. + var g execution.TaskGroup + g.Run(work(25, nil)) + err := g.Wait() + require.NoError(t, err) + + // Verify that the group can be reused. + g.Run(work(50, nil)) + g.Run(work(75, nil)) + err = g.Wait() + require.NoError(t, err) + + // Verify that error is propagated without an error handler. + g.Run(work(50, errors.New("expected error"))) + err = g.Wait() + require.Error(t, err) +} + +func TestErrorsPropagation(t *testing.T) { + defer goleak.VerifyNone(t) + + expected := errors.New("expected error") + + var g execution.TaskGroup + g.Run(func() error { return expected }) + err := g.Wait() + require.ErrorIs(t, err, expected) + + g.OnError(func(error) error { return nil }) // discard all error + g.Run(func() error { return expected }) + err = g.Wait() + require.NoError(t, err) +} + +func TestCancelPropagation(t *testing.T) { + defer goleak.VerifyNone(t) + + const numTasks = 64 + + var errs []error + g := execution.NewTaskGroup(func(err error) error { + errs = append(errs, err) // Only collect non-nil errors and suppress them. + return nil + }) + + errOther := errors.New("something is wrong") + ctx, cancel := context.WithCancel(context.Background()) + var numOK int32 + for range numTasks { + g.Run(func() error { + d1 := randomDuration(2) + d2 := randomDuration(2) + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(d1): + return errOther + case <-time.After(d2): + atomic.AddInt32(&numOK, 1) // Count successful executions. + return nil + } + }) + } + runtime.Gosched() + <-time.After(500 * time.Microsecond) + cancel() + + err := g.Wait() + require.NoError(t, err) // No captured error is expected, should be suppressed. + + var numCanceled, numOther int + for _, e := range errs { + switch { + case errors.Is(e, context.Canceled): + numCanceled++ + case errors.Is(e, errOther): + numOther++ + default: + require.FailNow(t, "unexpected error: %v", e) + } + } + + assert.NotZero(t, numOK) + assert.NotZero(t, numCanceled) + assert.NotZero(t, numOther) + total := int(numOK) + numCanceled + numOther + assert.Equal(t, numTasks, total) +} + +func TestWaitingForFinish(t *testing.T) { + defer goleak.VerifyNone(t) + + ctx, cancel := context.WithCancel(context.Background()) + + failure := errors.New("failure") + exec := func() error { + select { + case <-ctx.Done(): + return work(50, nil)() + case <-time.After(randomDuration(60)): + return failure + } + } + + var g execution.TaskGroup + g.Run(exec) + g.Run(exec) + g.Run(exec) + + cancel() + + err := g.Wait() + require.NoError(t, err) +} + +func TestRegression(t *testing.T) { + t.Run("WaitRace", func(_ *testing.T) { + ready := make(chan struct{}) + var g execution.TaskGroup + g.Run(func() error { + <-ready + return nil + }) + + var wg sync.WaitGroup + wg.Add(2) + go func() { defer wg.Done(); _ = g.Wait() }() + go func() { defer wg.Done(); _ = g.Wait() }() + + close(ready) + wg.Wait() + }) + t.Run("WaitUnstarted", func(t *testing.T) { + defer func() { + if x := recover(); x != nil { + t.Errorf("Unexpected panic: %v", x) + } + }() + var g execution.TaskGroup + _ = g.Wait() + }) +} + +func randomDuration(n int64) time.Duration { + return time.Duration(rand.Int64N(n)) * time.Millisecond +} + +// work returns an execution function that does nothing for random number of ms with [n] ms upper limit and returns err. +func work(n int64, err error) func() error { + return func() error { time.Sleep(randomDuration(n)); return err } +} From 561302add7fd343c86b93f0907f5fe2ca1244f51 Mon Sep 17 00:00:00 2001 From: Alexey Kiselev Date: Fri, 15 Nov 2024 12:26:07 +0400 Subject: [PATCH 02/30] Networking package with a new connection handler Session added. --- .mockery.yaml | 12 + go.mod | 2 + go.sum | 4 + pkg/networking/address.go | 23 ++ pkg/networking/configuration.go | 59 ++++ pkg/networking/handler.go | 13 + pkg/networking/logger.go | 96 ++++++ pkg/networking/mocks/handler.go | 136 +++++++++ pkg/networking/mocks/header.go | 238 +++++++++++++++ pkg/networking/mocks/protocol.go | 278 ++++++++++++++++++ pkg/networking/network.go | 51 ++++ pkg/networking/protocol.go | 38 +++ pkg/networking/session.go | 404 +++++++++++++++++++++++++ pkg/networking/session_test.go | 487 +++++++++++++++++++++++++++++++ pkg/networking/timers.go | 41 +++ 15 files changed, 1882 insertions(+) create mode 100644 .mockery.yaml create mode 100644 pkg/networking/address.go create mode 100644 pkg/networking/configuration.go create mode 100644 pkg/networking/handler.go create mode 100644 pkg/networking/logger.go create mode 100644 pkg/networking/mocks/handler.go create mode 100644 pkg/networking/mocks/header.go create mode 100644 pkg/networking/mocks/protocol.go create mode 100644 pkg/networking/network.go create mode 100644 pkg/networking/protocol.go create mode 100644 pkg/networking/session.go create mode 100644 pkg/networking/session_test.go create mode 100644 pkg/networking/timers.go diff --git a/.mockery.yaml b/.mockery.yaml new file mode 100644 index 000000000..d7430c317 --- /dev/null +++ b/.mockery.yaml @@ -0,0 +1,12 @@ +quiet: False +with-expecter: True +dir: "{{.InterfaceDir}}/mocks" +mockname: "Mock{{.InterfaceName}}" +filename: "{{.InterfaceName | snakecase}}.go" + +packages: + github.com/wavesplatform/gowaves/pkg/networking: + interfaces: + Header: + Protocol: + Handler: diff --git a/go.mod b/go.mod index b946f41e2..5b7906d79 100644 --- a/go.mod +++ b/go.mod @@ -22,6 +22,7 @@ require ( github.com/influxdata/influxdb1-client v0.0.0-20200827194710-b269163b24ab github.com/jinzhu/copier v0.4.0 github.com/mr-tron/base58 v1.2.0 + github.com/neilotoole/slogt v1.1.0 github.com/ory/dockertest/v3 v3.11.0 github.com/phayes/freeport v0.0.0-20180830031419-95f893ade6f2 github.com/pkg/errors v0.9.1 @@ -99,6 +100,7 @@ require ( github.com/rs/zerolog v1.33.0 // indirect github.com/sirupsen/logrus v1.9.3 // indirect github.com/steakknife/hamming v0.0.0-20180906055917-c99c65617cd3 // indirect + github.com/stretchr/objx v0.5.2 // indirect github.com/tidwall/gjson v1.14.2 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect diff --git a/go.sum b/go.sum index ec8bd95bd..c48b2feb8 100644 --- a/go.sum +++ b/go.sum @@ -191,6 +191,8 @@ github.com/mr-tron/base58 v1.2.0/go.mod h1:BinMc/sQntlIE1frQmRFPUoPA1Zkr8VRgBdjW github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= +github.com/neilotoole/slogt v1.1.0 h1:c7qE92sq+V0yvCuaxph+RQ2jOKL61c4hqS1Bv9W7FZE= +github.com/neilotoole/slogt v1.1.0/go.mod h1:RCrGXkPc/hYybNulqQrMHRtvlQ7F6NktNVLuLwk6V+w= github.com/nxadm/tail v1.4.4 h1:DQuhQpB1tVlglWS2hLQ5OV6B5r8aGxSrPc5Qo6uTN78= github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= @@ -274,6 +276,8 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= diff --git a/pkg/networking/address.go b/pkg/networking/address.go new file mode 100644 index 000000000..c1ebf9ca3 --- /dev/null +++ b/pkg/networking/address.go @@ -0,0 +1,23 @@ +package networking + +import ( + "fmt" + "net" +) + +type addressable interface { + LocalAddr() net.Addr + RemoteAddr() net.Addr +} + +type sessionAddress struct { + addr string +} + +func (*sessionAddress) Network() string { + return "session" +} + +func (a *sessionAddress) String() string { + return fmt.Sprintf("session:%s", a.addr) +} diff --git a/pkg/networking/configuration.go b/pkg/networking/configuration.go new file mode 100644 index 000000000..497c221b5 --- /dev/null +++ b/pkg/networking/configuration.go @@ -0,0 +1,59 @@ +package networking + +import ( + "log/slog" + "time" +) + +const ( + defaultKeepAliveInterval = 1 * time.Minute + defaultConnectionWriteTimeout = 15 * time.Second +) + +// Config allows to set some parameters of the [Conn] or it's underlying connection. +type Config struct { + logger Logger + protocol Protocol + handler Handler + keepAlive bool + keepAliveInterval time.Duration + connectionWriteTimeout time.Duration + attributes []any +} + +// NewConfig creates a new Config and sets required Protocol and Handler parameters. +// Other parameters are set to their default values. +func NewConfig(p Protocol, h Handler) *Config { + return &Config{ + logger: noopLogger{}, + protocol: p, + handler: h, + keepAlive: true, + keepAliveInterval: defaultKeepAliveInterval, + connectionWriteTimeout: defaultConnectionWriteTimeout, + attributes: nil, + } +} + +// WithLogger sets the logger. +func (c *Config) WithLogger(logger Logger) *Config { + c.logger = logger + return c +} + +// WithWriteTimeout sets connection write timeout attribute to the Config. +func (c *Config) WithWriteTimeout(timeout time.Duration) *Config { + c.connectionWriteTimeout = timeout + return c +} + +// WithSlogAttribute adds an attribute to the slice of attributes. +func (c *Config) WithSlogAttribute(attr slog.Attr) *Config { + c.attributes = append(c.attributes, attr) + return c +} + +func (c *Config) WithKeepAliveDisabled() *Config { + c.keepAlive = false + return c +} diff --git a/pkg/networking/handler.go b/pkg/networking/handler.go new file mode 100644 index 000000000..2f4f62587 --- /dev/null +++ b/pkg/networking/handler.go @@ -0,0 +1,13 @@ +package networking + +// Handler is an interface for handling new messages, handshakes and session close events. +type Handler interface { + // OnReceive fired on new message received. + OnReceive(*Session, []byte) + + // OnHandshake fired on new Handshake received. + OnHandshake(*Session, Handshake) + + // OnClose fired on Session closed. + OnClose(*Session) +} diff --git a/pkg/networking/logger.go b/pkg/networking/logger.go new file mode 100644 index 000000000..c518b27e0 --- /dev/null +++ b/pkg/networking/logger.go @@ -0,0 +1,96 @@ +package networking + +import ( + "context" +) + +const Namespace = "NET" + +type Logger interface { + // Debug logs a message at the debug level. + Debug(msg string, args ...any) + + // DebugContext logs a message at the debug level with access to the context's values + DebugContext(ctx context.Context, msg string, args ...any) + + // Info logs a message at the info level. + Info(msg string, args ...any) + + // InfoContext logs a message at the info level with access to the context's values + InfoContext(ctx context.Context, msg string, args ...any) + + // Warn logs a message at the warn level. + Warn(msg string, args ...any) + + // WarnContext logs a message at the warn level with access to the context's values + WarnContext(ctx context.Context, msg string, args ...any) + + // Error logs a message at the error level. + Error(msg string, args ...any) + + // ErrorContext logs a message at the error level with access to the context's values + ErrorContext(ctx context.Context, msg string, args ...any) +} + +type wrappingLogger struct { + logger Logger + attributes []any +} + +func (l *wrappingLogger) Debug(msg string, args ...any) { + args = append(args, l.attributes...) + l.logger.Debug(msg, args...) +} + +func (l *wrappingLogger) DebugContext(ctx context.Context, msg string, args ...any) { + args = append(args, l.attributes...) + l.logger.DebugContext(ctx, msg, args...) +} + +func (l *wrappingLogger) Info(msg string, args ...any) { + args = append(args, l.attributes...) + l.logger.Info(msg, args...) +} + +func (l *wrappingLogger) InfoContext(ctx context.Context, msg string, args ...any) { + args = append(args, l.attributes...) + l.logger.InfoContext(ctx, msg, args...) +} + +func (l *wrappingLogger) Warn(msg string, args ...any) { + args = append(args, l.attributes...) + l.logger.Warn(msg, args...) +} + +func (l *wrappingLogger) WarnContext(ctx context.Context, msg string, args ...any) { + args = append(args, l.attributes...) + l.logger.WarnContext(ctx, msg, args...) +} + +func (l *wrappingLogger) Error(msg string, args ...any) { + args = append(args, l.attributes...) + l.logger.Error(msg, args...) +} + +func (l *wrappingLogger) ErrorContext(ctx context.Context, msg string, args ...any) { + args = append(args, l.attributes...) + l.logger.ErrorContext(ctx, msg, args...) +} + +type noopLogger struct{} + +func (noopLogger) Debug(string, ...any) {} + +func (noopLogger) DebugContext(context.Context, string, ...any) {} + +func (noopLogger) Info(string, ...any) {} + +func (noopLogger) InfoContext(context.Context, string, ...any) {} + +func (noopLogger) Warn(string, ...any) {} + +func (noopLogger) WarnContext(context.Context, string, ...any) {} + +func (noopLogger) Error(string, ...any) {} + +func (noopLogger) ErrorContext(context.Context, string, ...any) {} diff --git a/pkg/networking/mocks/handler.go b/pkg/networking/mocks/handler.go new file mode 100644 index 000000000..d7ba29dd3 --- /dev/null +++ b/pkg/networking/mocks/handler.go @@ -0,0 +1,136 @@ +// Code generated by mockery v2.46.3. DO NOT EDIT. + +package networking + +import ( + mock "github.com/stretchr/testify/mock" + networking "github.com/wavesplatform/gowaves/pkg/networking" +) + +// MockHandler is an autogenerated mock type for the Handler type +type MockHandler struct { + mock.Mock +} + +type MockHandler_Expecter struct { + mock *mock.Mock +} + +func (_m *MockHandler) EXPECT() *MockHandler_Expecter { + return &MockHandler_Expecter{mock: &_m.Mock} +} + +// OnClose provides a mock function with given fields: _a0 +func (_m *MockHandler) OnClose(_a0 *networking.Session) { + _m.Called(_a0) +} + +// MockHandler_OnClose_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'OnClose' +type MockHandler_OnClose_Call struct { + *mock.Call +} + +// OnClose is a helper method to define mock.On call +// - _a0 *networking.Session +func (_e *MockHandler_Expecter) OnClose(_a0 interface{}) *MockHandler_OnClose_Call { + return &MockHandler_OnClose_Call{Call: _e.mock.On("OnClose", _a0)} +} + +func (_c *MockHandler_OnClose_Call) Run(run func(_a0 *networking.Session)) *MockHandler_OnClose_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(*networking.Session)) + }) + return _c +} + +func (_c *MockHandler_OnClose_Call) Return() *MockHandler_OnClose_Call { + _c.Call.Return() + return _c +} + +func (_c *MockHandler_OnClose_Call) RunAndReturn(run func(*networking.Session)) *MockHandler_OnClose_Call { + _c.Call.Return(run) + return _c +} + +// OnHandshake provides a mock function with given fields: _a0, _a1 +func (_m *MockHandler) OnHandshake(_a0 *networking.Session, _a1 networking.Handshake) { + _m.Called(_a0, _a1) +} + +// MockHandler_OnHandshake_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'OnHandshake' +type MockHandler_OnHandshake_Call struct { + *mock.Call +} + +// OnHandshake is a helper method to define mock.On call +// - _a0 *networking.Session +// - _a1 networking.Handshake +func (_e *MockHandler_Expecter) OnHandshake(_a0 interface{}, _a1 interface{}) *MockHandler_OnHandshake_Call { + return &MockHandler_OnHandshake_Call{Call: _e.mock.On("OnHandshake", _a0, _a1)} +} + +func (_c *MockHandler_OnHandshake_Call) Run(run func(_a0 *networking.Session, _a1 networking.Handshake)) *MockHandler_OnHandshake_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(*networking.Session), args[1].(networking.Handshake)) + }) + return _c +} + +func (_c *MockHandler_OnHandshake_Call) Return() *MockHandler_OnHandshake_Call { + _c.Call.Return() + return _c +} + +func (_c *MockHandler_OnHandshake_Call) RunAndReturn(run func(*networking.Session, networking.Handshake)) *MockHandler_OnHandshake_Call { + _c.Call.Return(run) + return _c +} + +// OnReceive provides a mock function with given fields: _a0, _a1 +func (_m *MockHandler) OnReceive(_a0 *networking.Session, _a1 []byte) { + _m.Called(_a0, _a1) +} + +// MockHandler_OnReceive_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'OnReceive' +type MockHandler_OnReceive_Call struct { + *mock.Call +} + +// OnReceive is a helper method to define mock.On call +// - _a0 *networking.Session +// - _a1 []byte +func (_e *MockHandler_Expecter) OnReceive(_a0 interface{}, _a1 interface{}) *MockHandler_OnReceive_Call { + return &MockHandler_OnReceive_Call{Call: _e.mock.On("OnReceive", _a0, _a1)} +} + +func (_c *MockHandler_OnReceive_Call) Run(run func(_a0 *networking.Session, _a1 []byte)) *MockHandler_OnReceive_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(*networking.Session), args[1].([]byte)) + }) + return _c +} + +func (_c *MockHandler_OnReceive_Call) Return() *MockHandler_OnReceive_Call { + _c.Call.Return() + return _c +} + +func (_c *MockHandler_OnReceive_Call) RunAndReturn(run func(*networking.Session, []byte)) *MockHandler_OnReceive_Call { + _c.Call.Return(run) + return _c +} + +// NewMockHandler creates a new instance of MockHandler. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockHandler(t interface { + mock.TestingT + Cleanup(func()) +}) *MockHandler { + mock := &MockHandler{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/pkg/networking/mocks/header.go b/pkg/networking/mocks/header.go new file mode 100644 index 000000000..1de986214 --- /dev/null +++ b/pkg/networking/mocks/header.go @@ -0,0 +1,238 @@ +// Code generated by mockery v2.46.3. DO NOT EDIT. + +package networking + +import ( + io "io" + + mock "github.com/stretchr/testify/mock" +) + +// MockHeader is an autogenerated mock type for the Header type +type MockHeader struct { + mock.Mock +} + +type MockHeader_Expecter struct { + mock *mock.Mock +} + +func (_m *MockHeader) EXPECT() *MockHeader_Expecter { + return &MockHeader_Expecter{mock: &_m.Mock} +} + +// HeaderLength provides a mock function with given fields: +func (_m *MockHeader) HeaderLength() uint32 { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for HeaderLength") + } + + var r0 uint32 + if rf, ok := ret.Get(0).(func() uint32); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(uint32) + } + + return r0 +} + +// MockHeader_HeaderLength_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'HeaderLength' +type MockHeader_HeaderLength_Call struct { + *mock.Call +} + +// HeaderLength is a helper method to define mock.On call +func (_e *MockHeader_Expecter) HeaderLength() *MockHeader_HeaderLength_Call { + return &MockHeader_HeaderLength_Call{Call: _e.mock.On("HeaderLength")} +} + +func (_c *MockHeader_HeaderLength_Call) Run(run func()) *MockHeader_HeaderLength_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockHeader_HeaderLength_Call) Return(_a0 uint32) *MockHeader_HeaderLength_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockHeader_HeaderLength_Call) RunAndReturn(run func() uint32) *MockHeader_HeaderLength_Call { + _c.Call.Return(run) + return _c +} + +// PayloadLength provides a mock function with given fields: +func (_m *MockHeader) PayloadLength() uint32 { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for PayloadLength") + } + + var r0 uint32 + if rf, ok := ret.Get(0).(func() uint32); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(uint32) + } + + return r0 +} + +// MockHeader_PayloadLength_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'PayloadLength' +type MockHeader_PayloadLength_Call struct { + *mock.Call +} + +// PayloadLength is a helper method to define mock.On call +func (_e *MockHeader_Expecter) PayloadLength() *MockHeader_PayloadLength_Call { + return &MockHeader_PayloadLength_Call{Call: _e.mock.On("PayloadLength")} +} + +func (_c *MockHeader_PayloadLength_Call) Run(run func()) *MockHeader_PayloadLength_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockHeader_PayloadLength_Call) Return(_a0 uint32) *MockHeader_PayloadLength_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockHeader_PayloadLength_Call) RunAndReturn(run func() uint32) *MockHeader_PayloadLength_Call { + _c.Call.Return(run) + return _c +} + +// ReadFrom provides a mock function with given fields: r +func (_m *MockHeader) ReadFrom(r io.Reader) (int64, error) { + ret := _m.Called(r) + + if len(ret) == 0 { + panic("no return value specified for ReadFrom") + } + + var r0 int64 + var r1 error + if rf, ok := ret.Get(0).(func(io.Reader) (int64, error)); ok { + return rf(r) + } + if rf, ok := ret.Get(0).(func(io.Reader) int64); ok { + r0 = rf(r) + } else { + r0 = ret.Get(0).(int64) + } + + if rf, ok := ret.Get(1).(func(io.Reader) error); ok { + r1 = rf(r) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockHeader_ReadFrom_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ReadFrom' +type MockHeader_ReadFrom_Call struct { + *mock.Call +} + +// ReadFrom is a helper method to define mock.On call +// - r io.Reader +func (_e *MockHeader_Expecter) ReadFrom(r interface{}) *MockHeader_ReadFrom_Call { + return &MockHeader_ReadFrom_Call{Call: _e.mock.On("ReadFrom", r)} +} + +func (_c *MockHeader_ReadFrom_Call) Run(run func(r io.Reader)) *MockHeader_ReadFrom_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(io.Reader)) + }) + return _c +} + +func (_c *MockHeader_ReadFrom_Call) Return(n int64, err error) *MockHeader_ReadFrom_Call { + _c.Call.Return(n, err) + return _c +} + +func (_c *MockHeader_ReadFrom_Call) RunAndReturn(run func(io.Reader) (int64, error)) *MockHeader_ReadFrom_Call { + _c.Call.Return(run) + return _c +} + +// WriteTo provides a mock function with given fields: w +func (_m *MockHeader) WriteTo(w io.Writer) (int64, error) { + ret := _m.Called(w) + + if len(ret) == 0 { + panic("no return value specified for WriteTo") + } + + var r0 int64 + var r1 error + if rf, ok := ret.Get(0).(func(io.Writer) (int64, error)); ok { + return rf(w) + } + if rf, ok := ret.Get(0).(func(io.Writer) int64); ok { + r0 = rf(w) + } else { + r0 = ret.Get(0).(int64) + } + + if rf, ok := ret.Get(1).(func(io.Writer) error); ok { + r1 = rf(w) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockHeader_WriteTo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'WriteTo' +type MockHeader_WriteTo_Call struct { + *mock.Call +} + +// WriteTo is a helper method to define mock.On call +// - w io.Writer +func (_e *MockHeader_Expecter) WriteTo(w interface{}) *MockHeader_WriteTo_Call { + return &MockHeader_WriteTo_Call{Call: _e.mock.On("WriteTo", w)} +} + +func (_c *MockHeader_WriteTo_Call) Run(run func(w io.Writer)) *MockHeader_WriteTo_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(io.Writer)) + }) + return _c +} + +func (_c *MockHeader_WriteTo_Call) Return(n int64, err error) *MockHeader_WriteTo_Call { + _c.Call.Return(n, err) + return _c +} + +func (_c *MockHeader_WriteTo_Call) RunAndReturn(run func(io.Writer) (int64, error)) *MockHeader_WriteTo_Call { + _c.Call.Return(run) + return _c +} + +// NewMockHeader creates a new instance of MockHeader. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockHeader(t interface { + mock.TestingT + Cleanup(func()) +}) *MockHeader { + mock := &MockHeader{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/pkg/networking/mocks/protocol.go b/pkg/networking/mocks/protocol.go new file mode 100644 index 000000000..19afa9cff --- /dev/null +++ b/pkg/networking/mocks/protocol.go @@ -0,0 +1,278 @@ +// Code generated by mockery v2.46.3. DO NOT EDIT. + +package networking + +import ( + mock "github.com/stretchr/testify/mock" + networking "github.com/wavesplatform/gowaves/pkg/networking" +) + +// MockProtocol is an autogenerated mock type for the Protocol type +type MockProtocol struct { + mock.Mock +} + +type MockProtocol_Expecter struct { + mock *mock.Mock +} + +func (_m *MockProtocol) EXPECT() *MockProtocol_Expecter { + return &MockProtocol_Expecter{mock: &_m.Mock} +} + +// EmptyHandshake provides a mock function with given fields: +func (_m *MockProtocol) EmptyHandshake() networking.Handshake { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for EmptyHandshake") + } + + var r0 networking.Handshake + if rf, ok := ret.Get(0).(func() networking.Handshake); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(networking.Handshake) + } + } + + return r0 +} + +// MockProtocol_EmptyHandshake_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'EmptyHandshake' +type MockProtocol_EmptyHandshake_Call struct { + *mock.Call +} + +// EmptyHandshake is a helper method to define mock.On call +func (_e *MockProtocol_Expecter) EmptyHandshake() *MockProtocol_EmptyHandshake_Call { + return &MockProtocol_EmptyHandshake_Call{Call: _e.mock.On("EmptyHandshake")} +} + +func (_c *MockProtocol_EmptyHandshake_Call) Run(run func()) *MockProtocol_EmptyHandshake_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockProtocol_EmptyHandshake_Call) Return(_a0 networking.Handshake) *MockProtocol_EmptyHandshake_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockProtocol_EmptyHandshake_Call) RunAndReturn(run func() networking.Handshake) *MockProtocol_EmptyHandshake_Call { + _c.Call.Return(run) + return _c +} + +// EmptyHeader provides a mock function with given fields: +func (_m *MockProtocol) EmptyHeader() networking.Header { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for EmptyHeader") + } + + var r0 networking.Header + if rf, ok := ret.Get(0).(func() networking.Header); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(networking.Header) + } + } + + return r0 +} + +// MockProtocol_EmptyHeader_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'EmptyHeader' +type MockProtocol_EmptyHeader_Call struct { + *mock.Call +} + +// EmptyHeader is a helper method to define mock.On call +func (_e *MockProtocol_Expecter) EmptyHeader() *MockProtocol_EmptyHeader_Call { + return &MockProtocol_EmptyHeader_Call{Call: _e.mock.On("EmptyHeader")} +} + +func (_c *MockProtocol_EmptyHeader_Call) Run(run func()) *MockProtocol_EmptyHeader_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockProtocol_EmptyHeader_Call) Return(_a0 networking.Header) *MockProtocol_EmptyHeader_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockProtocol_EmptyHeader_Call) RunAndReturn(run func() networking.Header) *MockProtocol_EmptyHeader_Call { + _c.Call.Return(run) + return _c +} + +// IsAcceptableHandshake provides a mock function with given fields: _a0 +func (_m *MockProtocol) IsAcceptableHandshake(_a0 networking.Handshake) bool { + ret := _m.Called(_a0) + + if len(ret) == 0 { + panic("no return value specified for IsAcceptableHandshake") + } + + var r0 bool + if rf, ok := ret.Get(0).(func(networking.Handshake) bool); ok { + r0 = rf(_a0) + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// MockProtocol_IsAcceptableHandshake_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'IsAcceptableHandshake' +type MockProtocol_IsAcceptableHandshake_Call struct { + *mock.Call +} + +// IsAcceptableHandshake is a helper method to define mock.On call +// - _a0 networking.Handshake +func (_e *MockProtocol_Expecter) IsAcceptableHandshake(_a0 interface{}) *MockProtocol_IsAcceptableHandshake_Call { + return &MockProtocol_IsAcceptableHandshake_Call{Call: _e.mock.On("IsAcceptableHandshake", _a0)} +} + +func (_c *MockProtocol_IsAcceptableHandshake_Call) Run(run func(_a0 networking.Handshake)) *MockProtocol_IsAcceptableHandshake_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(networking.Handshake)) + }) + return _c +} + +func (_c *MockProtocol_IsAcceptableHandshake_Call) Return(_a0 bool) *MockProtocol_IsAcceptableHandshake_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockProtocol_IsAcceptableHandshake_Call) RunAndReturn(run func(networking.Handshake) bool) *MockProtocol_IsAcceptableHandshake_Call { + _c.Call.Return(run) + return _c +} + +// IsAcceptableMessage provides a mock function with given fields: _a0 +func (_m *MockProtocol) IsAcceptableMessage(_a0 networking.Header) bool { + ret := _m.Called(_a0) + + if len(ret) == 0 { + panic("no return value specified for IsAcceptableMessage") + } + + var r0 bool + if rf, ok := ret.Get(0).(func(networking.Header) bool); ok { + r0 = rf(_a0) + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// MockProtocol_IsAcceptableMessage_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'IsAcceptableMessage' +type MockProtocol_IsAcceptableMessage_Call struct { + *mock.Call +} + +// IsAcceptableMessage is a helper method to define mock.On call +// - _a0 networking.Header +func (_e *MockProtocol_Expecter) IsAcceptableMessage(_a0 interface{}) *MockProtocol_IsAcceptableMessage_Call { + return &MockProtocol_IsAcceptableMessage_Call{Call: _e.mock.On("IsAcceptableMessage", _a0)} +} + +func (_c *MockProtocol_IsAcceptableMessage_Call) Run(run func(_a0 networking.Header)) *MockProtocol_IsAcceptableMessage_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(networking.Header)) + }) + return _c +} + +func (_c *MockProtocol_IsAcceptableMessage_Call) Return(_a0 bool) *MockProtocol_IsAcceptableMessage_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockProtocol_IsAcceptableMessage_Call) RunAndReturn(run func(networking.Header) bool) *MockProtocol_IsAcceptableMessage_Call { + _c.Call.Return(run) + return _c +} + +// Ping provides a mock function with given fields: +func (_m *MockProtocol) Ping() ([]byte, error) { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Ping") + } + + var r0 []byte + var r1 error + if rf, ok := ret.Get(0).(func() ([]byte, error)); ok { + return rf() + } + if rf, ok := ret.Get(0).(func() []byte); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]byte) + } + } + + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockProtocol_Ping_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Ping' +type MockProtocol_Ping_Call struct { + *mock.Call +} + +// Ping is a helper method to define mock.On call +func (_e *MockProtocol_Expecter) Ping() *MockProtocol_Ping_Call { + return &MockProtocol_Ping_Call{Call: _e.mock.On("Ping")} +} + +func (_c *MockProtocol_Ping_Call) Run(run func()) *MockProtocol_Ping_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockProtocol_Ping_Call) Return(_a0 []byte, _a1 error) *MockProtocol_Ping_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockProtocol_Ping_Call) RunAndReturn(run func() ([]byte, error)) *MockProtocol_Ping_Call { + _c.Call.Return(run) + return _c +} + +// NewMockProtocol creates a new instance of MockProtocol. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockProtocol(t interface { + mock.TestingT + Cleanup(func()) +}) *MockProtocol { + mock := &MockProtocol{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/pkg/networking/network.go b/pkg/networking/network.go new file mode 100644 index 000000000..f0dabec3f --- /dev/null +++ b/pkg/networking/network.go @@ -0,0 +1,51 @@ +package networking + +import ( + "context" + "errors" + "fmt" + "io" +) + +var ( + // ErrInvalidConfigurationNoProtocol is used when the configuration has no protocol. + ErrInvalidConfigurationNoProtocol = errors.New("invalid configuration: empty protocol") + + // ErrInvalidConfigurationNoHandler is used when the configuration has no handler. + ErrInvalidConfigurationNoHandler = errors.New("invalid configuration: empty handler") + + // ErrUnacceptableHandshake is used when the handshake is not accepted. + ErrUnacceptableHandshake = errors.New("handshake is not accepted") + + // ErrSessionShutdown is used if there is a shutdown during an operation. + ErrSessionShutdown = errors.New("session shutdown") + + // ErrConnectionWriteTimeout indicates that we hit the timeout writing to the underlying stream connection. + ErrConnectionWriteTimeout = fmt.Errorf("connection write timeout") + + // ErrKeepAliveProtocolFailure is used when the protocol failed to provide a keep-alive message. + ErrKeepAliveProtocolFailure = errors.New("protocol failed to provide a keep-alive message") + + // ErrConnectionClosedOnRead indicates that the connection was closed while reading. + ErrConnectionClosedOnRead = errors.New("connection closed on read") + + // ErrKeepAliveTimeout indicates that we failed to send keep-alive message and abandon a keep-alive loop. + ErrKeepAliveTimeout = errors.New("keep-alive loop timeout") + + // ErrEmptyTimerPool is raised on creation of Session with a nil pool. + ErrEmptyTimerPool = errors.New("empty timer pool") +) + +type Network struct { + tp *timerPool +} + +func NewNetwork() *Network { + return &Network{ + tp: newTimerPool(), + } +} + +func (n *Network) NewSession(ctx context.Context, conn io.ReadWriteCloser, conf *Config) (*Session, error) { + return newSession(ctx, conf, conn, n.tp) +} diff --git a/pkg/networking/protocol.go b/pkg/networking/protocol.go new file mode 100644 index 000000000..a5b97ec25 --- /dev/null +++ b/pkg/networking/protocol.go @@ -0,0 +1,38 @@ +package networking + +import "io" + +// Header is the interface that should be implemented by the real message header packet. +type Header interface { + io.ReaderFrom + io.WriterTo + HeaderLength() uint32 + PayloadLength() uint32 +} + +// Handshake is the common interface for a handshake packet. +type Handshake interface { + io.ReaderFrom + io.WriterTo +} + +// Protocol is the interface for the network protocol implementation. +// It provides the methods to create the handshake packet, message header, and ping packet. +// It also provides the methods to validate the handshake and message header packets. +type Protocol interface { + // EmptyHandshake returns the empty instance of the handshake packet. + EmptyHandshake() Handshake + + // EmptyHeader returns the empty instance of the message header. + EmptyHeader() Header + + // Ping return the actual ping packet. + Ping() ([]byte, error) + + // IsAcceptableHandshake checks the handshake is acceptable. + IsAcceptableHandshake(Handshake) bool + + // IsAcceptableMessage checks the message is acceptable by examining its header. + // If return false, the message will be discarded. + IsAcceptableMessage(Header) bool +} diff --git a/pkg/networking/session.go b/pkg/networking/session.go new file mode 100644 index 000000000..97f8085a7 --- /dev/null +++ b/pkg/networking/session.go @@ -0,0 +1,404 @@ +package networking + +import ( + "bufio" + "bytes" + "context" + "encoding/base64" + "errors" + "io" + "log/slog" + "net" + "strings" + "sync" + "time" + + "github.com/wavesplatform/gowaves/pkg/execution" +) + +// Session is used to wrap a reliable ordered connection. +type Session struct { + g *execution.TaskGroup + ctx context.Context + cancel context.CancelFunc + + config *Config + logger Logger + tp *timerPool + + conn io.ReadWriteCloser // conn is the underlying connection + bufRead *bufio.Reader // buffered reader wrapped around the connection + + receiveLock sync.Mutex // Guards the receiveBuffer. + receiveBuffer *bytes.Buffer // receiveBuffer is used to store the incoming data. + + sendLock sync.Mutex // Guards the sendCh. + sendCh chan *sendPacket // sendCh is used to send data to the connection. + + establishedLock sync.Mutex // Guards the established field. + established bool // Indicates that incoming Handshake was successfully accepted. + shutdownLock sync.Mutex // Guards the shutdown field. + shutdown bool // shutdown is used to safely close the Session. +} + +// NewSession is used to construct a new session. +func newSession(ctx context.Context, config *Config, conn io.ReadWriteCloser, tp *timerPool) (*Session, error) { + if config.protocol == nil { + return nil, ErrInvalidConfigurationNoProtocol + } + if config.handler == nil { + return nil, ErrInvalidConfigurationNoHandler + } + if tp == nil { + return nil, ErrEmptyTimerPool + } + sCtx, cancel := context.WithCancel(ctx) + //TODO: Properly initialize sendCh + s := &Session{ + g: execution.NewTaskGroup(suppressContextCancellationError), + ctx: sCtx, + cancel: cancel, + config: config, + tp: tp, + conn: conn, + bufRead: bufio.NewReader(conn), + sendCh: make(chan *sendPacket, 1), + } + + attributes := []any{ + slog.String("namespace", Namespace), + slog.String("remote", s.RemoteAddr().String()), + } + attributes = append(attributes, config.attributes...) + + if config.logger == nil { + s.logger = noopLogger{} + } else { + s.logger = &wrappingLogger{ + logger: config.logger, + attributes: attributes, + } + } + + s.g.Run(s.receiveLoop) + s.g.Run(s.sendLoop) + if s.config.keepAlive { + s.g.Run(s.keepaliveLoop) + } + + return s, nil +} + +// LocalAddr returns the local network address. +func (s *Session) LocalAddr() net.Addr { + if a, ok := s.conn.(addressable); ok { + return a.LocalAddr() + } + return &sessionAddress{addr: "local"} +} + +// RemoteAddr returns the remote network address. +func (s *Session) RemoteAddr() net.Addr { + if a, ok := s.conn.(addressable); ok { + return a.RemoteAddr() + } + return &sessionAddress{addr: "remote"} +} + +// Close is used to close the session. It is safe to call Close multiple times from different goroutines, +// subsequent calls do nothing. +func (s *Session) Close() error { + s.shutdownLock.Lock() + defer s.shutdownLock.Unlock() + + if s.shutdown { + return nil // Fast path - session already closed. + } + s.shutdown = true + + s.logger.Debug("Closing session") + clErr := s.conn.Close() // Close the underlying connection. + if clErr != nil { + s.logger.Warn("Failed to close underlying connection", "error", clErr) + } + s.logger.Debug("Underlying connection closed") + + s.cancel() // Cancel the underlying context to interrupt the loops. + + s.logger.Debug("Waiting for loops to finish") + err := s.g.Wait() // Wait for loops to finish. + + err = errors.Join(err, clErr) // Combine loops finalization errors with connection close error. + + s.logger.Debug("Session closed", "error", err) + return err +} + +// Write is used to write to the session. It is safe to call Write and/or Close concurrently. +func (s *Session) Write(msg []byte) (int, error) { + s.sendLock.Lock() + defer s.sendLock.Unlock() + + if err := s.waitForSend(msg); err != nil { + return 0, err + } + + return len(msg), nil +} + +// waitForSend waits to send a data, checking for a potential context cancellation. +func (s *Session) waitForSend(data []byte) error { + // Channel to receive an error from sendLoop goroutine. + // We are not closing this channel, it will be GCed when the session is closed. + errCh := make(chan error, 1) + + timer := s.tp.Get() + timer.Reset(s.config.connectionWriteTimeout) + defer s.tp.Put(timer) + + s.logger.Debug("Sending data", "data", base64.StdEncoding.EncodeToString(data)) + ready := &sendPacket{data: data, err: errCh} + select { + case s.sendCh <- ready: + s.logger.Debug("Data written into send channel") + case <-s.ctx.Done(): + s.logger.Debug("Session shutdown while sending data") + return ErrSessionShutdown + case <-timer.C: + s.logger.Debug("Connection write timeout while sending data") + return ErrConnectionWriteTimeout + } + + dataCopy := func() { + if data == nil { + return // A nil data is ignored. + } + + // In the event of session shutdown or connection write timeout, we need to prevent `send` from reading + // the body buffer after returning from this function since the caller may re-use the underlying array. + ready.mu.Lock() + defer ready.mu.Unlock() + + if ready.data == nil { + return // data was already copied in `send`. + } + newData := make([]byte, len(data)) + copy(newData, data) + ready.data = newData + } + + select { + case err := <-errCh: + s.logger.Debug("Data sent", "error", err) + return err + case <-s.ctx.Done(): + dataCopy() + s.logger.Debug("Session shutdown while waiting send error") + return ErrSessionShutdown + case <-timer.C: + dataCopy() + s.logger.Debug("Connection write timeout while waiting send error") + return ErrConnectionWriteTimeout + } +} + +// sendLoop is a long-running goroutine that sends data to the connection. +func (s *Session) sendLoop() error { + var dataBuf bytes.Buffer + for { + dataBuf.Reset() + + select { + case <-s.ctx.Done(): + s.logger.Debug("Exiting connection send loop") + return s.ctx.Err() + + case packet := <-s.sendCh: + s.logger.Debug("Sending data to connection", + "data", base64.StdEncoding.EncodeToString(packet.data)) + packet.mu.Lock() + if packet.data != nil { + // Copy the data into the buffer to avoid holding a mutex lock during the writing. + _, err := dataBuf.Write(packet.data) + if err != nil { + packet.data = nil + packet.mu.Unlock() + s.logger.Error("Failed to copy data into buffer", "error", err) + s.asyncSendErr(packet.err, err) + return err // TODO: Do we need to return here? + } + s.logger.Debug("Data copied into buffer") + packet.data = nil + } + packet.mu.Unlock() + + if dataBuf.Len() > 0 { + s.logger.Debug("Writing data into connection", "len", len(dataBuf.Bytes())) + _, err := s.conn.Write(dataBuf.Bytes()) // TODO: We are locking here, because no timeout set on connection itself. + if err != nil { + s.logger.Error("Failed to write data into connection", "error", err) + s.asyncSendErr(packet.err, err) + return err + } + s.logger.Debug("Data written into connection") + } + + // No error, successful send + s.asyncSendErr(packet.err, nil) + } + } +} + +// receiveLoop continues to receive data until a fatal error is encountered or underlying connection is closed. +// Receive loop works after handshake and accepts only length-prepended messages. +func (s *Session) receiveLoop() error { + s.establishedLock.Lock() // Prevents from running multiple receiveLoops. + defer s.establishedLock.Unlock() + + for { + if err := s.receive(); err != nil { + if errors.Is(err, ErrConnectionClosedOnRead) { + s.config.handler.OnClose(s) + return nil // Exit normally on connection close. + } + return err + } + } +} + +func (s *Session) receive() error { + if s.established { + hdr := s.config.protocol.EmptyHeader() + return s.readMessage(hdr) + } + return s.readHandshake() +} + +func (s *Session) readHandshake() error { + s.logger.Debug("Reading handshake") + + hs := s.config.protocol.EmptyHandshake() + _, err := hs.ReadFrom(s.bufRead) + if err != nil { + if errors.Is(err, io.EOF) || strings.Contains(err.Error(), "closed") || + strings.Contains(err.Error(), "reset by peer") { + return ErrConnectionClosedOnRead + } + s.logger.Error("Failed to read handshake from connection", "error", err) + return err + } + s.logger.Debug("Handshake successfully read") + + if !s.config.protocol.IsAcceptableHandshake(hs) { + s.logger.Error("Handshake is not acceptable") + return ErrUnacceptableHandshake + } + // Handshake is acceptable, we can switch the session into established state. + s.established = true + s.config.handler.OnHandshake(s, hs) + return nil +} + +func (s *Session) readMessage(hdr Header) error { + // Read the header + if _, err := hdr.ReadFrom(s.bufRead); err != nil { + if errors.Is(err, io.EOF) || strings.Contains(err.Error(), "closed") || + strings.Contains(err.Error(), "reset by peer") { + return ErrConnectionClosedOnRead + } + s.logger.Error("Failed to read header", "error", err) + return err + } + if !s.config.protocol.IsAcceptableMessage(hdr) { + // We have to discard the remaining part of the message. + if _, err := io.CopyN(io.Discard, s.bufRead, int64(hdr.PayloadLength())); err != nil { + s.logger.Error("Failed to discard message", "error", err) + return err + } + } + // Read the new data + if err := s.readMessagePayload(hdr, s.bufRead); err != nil { + s.logger.Error("Failed to read message", "error", err) + return err + } + return nil +} + +func (s *Session) readMessagePayload(hdr Header, conn io.Reader) error { + // Wrap in a limited reader + conn = &io.LimitedReader{R: conn, N: int64(hdr.PayloadLength())} + + // Copy into buffer + s.receiveLock.Lock() + defer s.receiveLock.Unlock() + + if s.receiveBuffer == nil { + // Allocate the receiving buffer just-in-time to fit the full message. + s.receiveBuffer = bytes.NewBuffer(make([]byte, 0, hdr.HeaderLength()+hdr.PayloadLength())) + } + _, err := hdr.WriteTo(s.receiveBuffer) + if err != nil { + s.logger.Error("Failed to write header to receiving buffer", "error", err) + return err + } + _, err = io.Copy(s.receiveBuffer, conn) + if err != nil { + s.logger.Error("Failed to copy payload to receiving buffer", "error", err) + return err + } + // We lock the buffer from modification on the time of invocation of OnReceive handler. + // The slice of bytes passed into the handler is only valid for the duration of the handler invocation. + // So inside the handler better deserialize message or make a copy of the bytes. + s.config.handler.OnReceive(s, s.receiveBuffer.Bytes()) // Invoke OnReceive handler. + return nil +} + +// keepaliveLoop is a long-running goroutine that periodically sends a Ping message to keep the connection alive. +func (s *Session) keepaliveLoop() error { + for { + select { + case <-s.ctx.Done(): + return s.ctx.Err() + case <-time.After(s.config.keepAliveInterval): + // Get actual Ping message from Protocol. + p, err := s.config.protocol.Ping() + if err != nil { + s.logger.Error("Failed to get ping message", "error", err) + return ErrKeepAliveProtocolFailure + } + if sndErr := s.waitForSend(p); sndErr != nil { + if errors.Is(sndErr, ErrSessionShutdown) { + return nil // Exit normally on session termination. + } + s.logger.Error("Failed to send ping message", "error", err) + return ErrKeepAliveTimeout + } + } + } +} + +// sendPacket is used to send data. +type sendPacket struct { + mu sync.Mutex // Protects data from unsafe reads. + data []byte + err chan error +} + +// asyncSendErr is used to try an async send of an error. +func (s *Session) asyncSendErr(ch chan error, err error) { + if ch == nil { + return + } + select { + case ch <- err: + s.logger.Debug("Error sent to channel", "error", err) + default: + } +} + +func suppressContextCancellationError(err error) error { + if errors.Is(err, context.Canceled) { + return nil + } + return err +} diff --git a/pkg/networking/session_test.go b/pkg/networking/session_test.go new file mode 100644 index 000000000..393a6ec90 --- /dev/null +++ b/pkg/networking/session_test.go @@ -0,0 +1,487 @@ +package networking_test + +import ( + "context" + "encoding/binary" + "errors" + "io" + "log/slog" + "sync" + "testing" + "time" + + "github.com/neilotoole/slogt" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "go.uber.org/goleak" + + "github.com/wavesplatform/gowaves/pkg/networking" + netmocks "github.com/wavesplatform/gowaves/pkg/networking/mocks" +) + +func TestSuccessfulSession(t *testing.T) { + defer goleak.VerifyNone(t) + + p := netmocks.NewMockProtocol(t) + p.On("EmptyHandshake").Return(&textHandshake{}, nil) + p.On("IsAcceptableHandshake", &textHandshake{v: "hello"}).Once().Return(true) + p.On("IsAcceptableHandshake", &textHandshake{v: "hello"}).Once().Return(true) + p.On("EmptyHeader").Return(&textHeader{}, nil) + p.On("IsAcceptableMessage", &textHeader{l: 2}).Once().Return(true) + p.On("IsAcceptableMessage", &textHeader{l: 13}).Once().Return(true) + + clientHandler := netmocks.NewMockHandler(t) + serverHandler := netmocks.NewMockHandler(t) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + clientConn, serverConn := testConnPipe() + net := networking.NewNetwork() + + cs, err := net.NewSession(ctx, clientConn, testConfig(t, p, clientHandler, "client")) + require.NoError(t, err) + ss, err := net.NewSession(ctx, serverConn, testConfig(t, p, serverHandler, "server")) + require.NoError(t, err) + + var sWG sync.WaitGroup + var cWG sync.WaitGroup + sWG.Add(1) + go func() { + sc1 := serverHandler.On("OnHandshake", ss, &textHandshake{v: "hello"}).Once().Return() + sc1.Run(func(_ mock.Arguments) { + n, wErr := ss.Write([]byte("hello")) + require.NoError(t, wErr) + assert.Equal(t, 5, n) + }) + sc2 := serverHandler.On("OnReceive", ss, encodeMessage("Hello session")).Once().Return() + sc2.NotBefore(sc1). + Run(func(_ mock.Arguments) { + n, wErr := ss.Write(encodeMessage("Hi")) + require.NoError(t, wErr) + assert.Equal(t, 6, n) + sWG.Done() + }) + sWG.Wait() + }() + + cWG.Add(1) + cl1 := clientHandler.On("OnHandshake", cs, &textHandshake{v: "hello"}).Once().Return() + cl1.Run(func(_ mock.Arguments) { + n, wErr := cs.Write(encodeMessage("Hello session")) + require.NoError(t, wErr) + assert.Equal(t, 17, n) + }) + cl2 := clientHandler.On("OnReceive", cs, encodeMessage("Hi")).Once().Return() + cl2.NotBefore(cl1). + Run(func(_ mock.Arguments) { + cWG.Done() + }) + + n, err := cs.Write([]byte("hello")) // Send handshake to server. + require.NoError(t, err) + assert.Equal(t, 5, n) + + cWG.Wait() // Wait for server to finish. + + clientHandler.On("OnClose", cs).Return() + serverHandler.On("OnClose", ss).Return() + err = cs.Close() + assert.NoError(t, err) + err = ss.Close() + assert.NoError(t, err) +} + +func TestSessionTimeoutOnHandshake(t *testing.T) { + defer goleak.VerifyNone(t) + + mockProtocol := netmocks.NewMockProtocol(t) + + clientHandler := netmocks.NewMockHandler(t) + serverHandler := netmocks.NewMockHandler(t) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + clientConn, serverConn := testConnPipe() + net := networking.NewNetwork() + + clientSession, err := net.NewSession(ctx, clientConn, testConfig(t, mockProtocol, clientHandler, "client")) + require.NoError(t, err) + serverSession, err := net.NewSession(ctx, serverConn, testConfig(t, mockProtocol, serverHandler, "server")) + require.NoError(t, err) + + mockProtocol.On("EmptyHandshake").Return(&textHandshake{}, nil) + + serverHandler.On("OnClose", serverSession).Return() + + var wg sync.WaitGroup + wg.Add(1) + go func() { + // Close server. + err = serverSession.Close() + assert.NoError(t, err) + wg.Done() + }() + + // Lock + pc, ok := clientConn.(*pipeConn) + require.True(t, ok) + pc.writeBlocker.Lock() + + clientHandler.On("OnClose", clientSession).Return() + + // Send handshake to server, but writing will block because the clientConn is locked. + n, err := clientSession.Write([]byte("hello")) + require.Error(t, err) + assert.Equal(t, 0, n) + + time.Sleep(2 * time.Second) // Let timeout occur. + + // Unlock "timeout" and close client. + wg.Wait() + pc.writeBlocker.Unlock() + err = clientSession.Close() + assert.Error(t, err) +} + +func TestSessionTimeoutOnMessage(t *testing.T) { + defer goleak.VerifyNone(t) + + mockProtocol := netmocks.NewMockProtocol(t) + mockProtocol.On("EmptyHandshake").Return(&textHandshake{}, nil) + mockProtocol.On("IsAcceptableHandshake", &textHandshake{v: "hello"}).Once().Return(true) + mockProtocol.On("IsAcceptableHandshake", &textHandshake{v: "hello"}).Once().Return(true) + mockProtocol.On("EmptyHeader").Return(&textHeader{}, nil) + + clientHandler := netmocks.NewMockHandler(t) + serverHandler := netmocks.NewMockHandler(t) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + clientConn, serverConn := testConnPipe() + net := networking.NewNetwork() + + clientSession, err := net.NewSession(ctx, clientConn, testConfig(t, mockProtocol, clientHandler, "client")) + require.NoError(t, err) + serverSession, err := net.NewSession(ctx, serverConn, testConfig(t, mockProtocol, serverHandler, "server")) + require.NoError(t, err) + + pc, ok := clientConn.(*pipeConn) + require.True(t, ok) + + serverHandler.On("OnClose", serverSession).Return() + + var serverWG sync.WaitGroup + var clientWG sync.WaitGroup + serverWG.Add(1) + clientWG.Add(1) + go func() { + sc1 := serverHandler.On("OnHandshake", serverSession, &textHandshake{v: "hello"}).Once().Return() + sc1.Run(func(_ mock.Arguments) { + n, wErr := serverSession.Write([]byte("hello")) + require.NoError(t, wErr) + assert.Equal(t, 5, n) + serverWG.Done() + }) + serverWG.Wait() // Wait for finishing handshake before closing the pipe. + + // Lock pipe after replying with the handshake from server. + pc.writeBlocker.Lock() + clientWG.Done() // Signal that pipe is locked. + }() + + serverHandler.On("OnClose", serverSession).Return() + clientHandler.On("OnClose", clientSession).Return() + + // Send handshake to server. + n, err := clientSession.Write([]byte("hello")) + require.NoError(t, err) + assert.Equal(t, 5, n) + + cs1 := clientHandler.On("OnHandshake", clientSession, &textHandshake{v: "hello"}).Once().Return() + cs1.Run(func(_ mock.Arguments) { + clientWG.Wait() // Wait for pipe to be locked. + // On receiving handshake from server, send the message back to server. + _, msgErr := clientSession.Write(encodeMessage("Hello session")) + require.Error(t, msgErr) + }) + + time.Sleep(1 * time.Second) // Let timeout occur. + + err = serverSession.Close() + assert.NoError(t, err) // Expect no error on the server side. + + pc.writeBlocker.Unlock() // Unlock the pipe. + + err = clientSession.Close() + assert.Error(t, err) // Expect error because connection to the server already closed. +} + +func TestDoubleClose(t *testing.T) { + defer goleak.VerifyNone(t) + + mockProtocol := netmocks.NewMockProtocol(t) + mockProtocol.On("EmptyHandshake").Return(&textHandshake{}, nil) + + clientHandler := netmocks.NewMockHandler(t) + serverHandler := netmocks.NewMockHandler(t) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + clientConn, serverConn := testConnPipe() + net := networking.NewNetwork() + + clientSession, err := net.NewSession(ctx, clientConn, testConfig(t, mockProtocol, clientHandler, "client")) + require.NoError(t, err) + serverSession, err := net.NewSession(ctx, serverConn, testConfig(t, mockProtocol, serverHandler, "server")) + require.NoError(t, err) + + clientHandler.On("OnClose", clientSession).Return() + serverHandler.On("OnClose", serverSession).Return() + + err = clientSession.Close() + assert.NoError(t, err) + err = clientSession.Close() + assert.NoError(t, err) + + err = serverSession.Close() + assert.NoError(t, err) + err = serverSession.Close() + assert.NoError(t, err) +} + +func TestOnClosedByOtherSide(t *testing.T) { + defer goleak.VerifyNone(t) + + mockProtocol := netmocks.NewMockProtocol(t) + mockProtocol.On("EmptyHandshake").Return(&textHandshake{}, nil) + mockProtocol.On("IsAcceptableHandshake", &textHandshake{v: "hello"}).Once().Return(true) + mockProtocol.On("IsAcceptableHandshake", &textHandshake{v: "hello"}).Once().Return(true) + mockProtocol.On("EmptyHeader").Return(&textHeader{}, nil) + + clientHandler := netmocks.NewMockHandler(t) + serverHandler := netmocks.NewMockHandler(t) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + clientConn, serverConn := testConnPipe() + net := networking.NewNetwork() + + clientSession, err := net.NewSession(ctx, clientConn, testConfig(t, mockProtocol, clientHandler, "client")) + require.NoError(t, err) + serverSession, err := net.NewSession(ctx, serverConn, testConfig(t, mockProtocol, serverHandler, "server")) + require.NoError(t, err) + + var closeWG sync.WaitGroup + closeWG.Add(1) + + var wg sync.WaitGroup + wg.Add(2) + + serverHandler.On("OnClose", serverSession).Return() + sc1 := serverHandler.On("OnHandshake", serverSession, &textHandshake{v: "hello"}).Once().Return() + sc1.Run(func(_ mock.Arguments) { + n, wErr := serverSession.Write([]byte("hello")) + assert.NoError(t, wErr) + assert.Equal(t, 5, n) + go func() { + // Close server after client received the handshake from server. + closeWG.Wait() // Wait for client to receive server handshake. + clErr := serverSession.Close() + assert.NoError(t, clErr) + wg.Done() + }() + }) + + clientHandler.On("OnClose", clientSession).Return() + + // Send handshake to server. + n, err := clientSession.Write([]byte("hello")) + require.NoError(t, err) + assert.Equal(t, 5, n) + + cs1 := clientHandler.On("OnHandshake", clientSession, &textHandshake{v: "hello"}).Once().Return() + cs1.Run(func(_ mock.Arguments) { + // On receiving handshake from server, signal to close the server. + closeWG.Done() + // Try to send message to server, but it will fail because server is already closed. + time.Sleep(10 * time.Millisecond) // Wait for server to close. + _, msgErr := clientSession.Write(encodeMessage("Hello session")) + require.Error(t, msgErr) + wg.Done() + }) + + wg.Wait() // Wait for client to finish. + err = clientSession.Close() + assert.Error(t, err) // Close reports the same error, because it was registered in the send loop. +} + +func TestCloseParentContext(t *testing.T) { + defer goleak.VerifyNone(t) + + mockProtocol := netmocks.NewMockProtocol(t) + mockProtocol.On("EmptyHandshake").Return(&textHandshake{}, nil) + mockProtocol.On("IsAcceptableHandshake", &textHandshake{v: "hello"}).Once().Return(true) + mockProtocol.On("IsAcceptableHandshake", &textHandshake{v: "hello"}).Once().Return(true) + mockProtocol.On("EmptyHeader").Return(&textHeader{}, nil) + + clientHandler := netmocks.NewMockHandler(t) + serverHandler := netmocks.NewMockHandler(t) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + clientConn, serverConn := testConnPipe() + net := networking.NewNetwork() + + clientSession, err := net.NewSession(ctx, clientConn, testConfig(t, mockProtocol, clientHandler, "client")) + require.NoError(t, err) + serverSession, err := net.NewSession(ctx, serverConn, testConfig(t, mockProtocol, serverHandler, "server")) + require.NoError(t, err) + + var closeWG sync.WaitGroup + closeWG.Add(1) + + var wg sync.WaitGroup + wg.Add(2) + + serverHandler.On("OnClose", serverSession).Return() + sc1 := serverHandler.On("OnHandshake", serverSession, &textHandshake{v: "hello"}).Once().Return() + sc1.Run(func(_ mock.Arguments) { + n, wErr := serverSession.Write([]byte("hello")) + assert.NoError(t, wErr) + assert.Equal(t, 5, n) + go func() { + closeWG.Wait() // Wait for client to receive server handshake. + cancel() // Close parent context. + wg.Done() + }() + }) + + clientHandler.On("OnClose", clientSession).Return() + + // Send handshake to server. + n, err := clientSession.Write([]byte("hello")) + require.NoError(t, err) + assert.Equal(t, 5, n) + + cs1 := clientHandler.On("OnHandshake", clientSession, &textHandshake{v: "hello"}).Once().Return() + cs1.Run(func(_ mock.Arguments) { + // On receiving handshake from server, signal to close the server. + closeWG.Done() + // Try to send message to server, but it will fail because server is already closed. + time.Sleep(10 * time.Millisecond) // Wait for server to close. + _, msgErr := clientSession.Write(encodeMessage("Hello session")) + require.Error(t, msgErr) + wg.Done() + }) + + wg.Wait() // Wait for client to finish. + + err = clientSession.Close() + assert.NoError(t, err) + err = serverSession.Close() + assert.NoError(t, err) +} + +func testConfig(t testing.TB, p networking.Protocol, h networking.Handler, direction string) *networking.Config { + log := slogt.New(t) + return networking.NewConfig(p, h). + WithLogger(log). + WithWriteTimeout(1 * time.Second). + WithKeepAliveDisabled(). + WithSlogAttribute(slog.String("direction", direction)) +} + +type pipeConn struct { + reader *io.PipeReader + writer *io.PipeWriter + writeBlocker sync.Mutex +} + +func (p *pipeConn) Read(b []byte) (int, error) { + return p.reader.Read(b) +} + +func (p *pipeConn) Write(b []byte) (int, error) { + p.writeBlocker.Lock() + defer p.writeBlocker.Unlock() + return p.writer.Write(b) +} + +func (p *pipeConn) Close() error { + rErr := p.reader.Close() + wErr := p.writer.Close() + return errors.Join(rErr, wErr) +} + +func testConnPipe() (io.ReadWriteCloser, io.ReadWriteCloser) { + read1, write1 := io.Pipe() + read2, write2 := io.Pipe() + conn1 := &pipeConn{reader: read1, writer: write2} + conn2 := &pipeConn{reader: read2, writer: write1} + return conn1, conn2 +} + +func encodeMessage(s string) []byte { + msg := make([]byte, 4+len(s)) + binary.BigEndian.PutUint32(msg[:4], uint32(len(s))) + copy(msg[4:], s) + return msg +} + +// We have to use the "real" handshake, not a mock, because we are reading or writing to a "real" piped connection. +type textHandshake struct { + v string +} + +func (h *textHandshake) ReadFrom(r io.Reader) (int64, error) { + buf := make([]byte, 5) + n, err := io.ReadFull(r, buf) + if err != nil { + return int64(n), err + } + h.v = string(buf[:n]) + return int64(n), nil +} + +func (h *textHandshake) WriteTo(w io.Writer) (int64, error) { + buf := []byte(h.v) + n, err := w.Write(buf) + return int64(n), err +} + +// We have to use the "real" header, not a mock, because we are reading or writing to a "real" piped connection. +type textHeader struct { + l uint32 +} + +func (h *textHeader) HeaderLength() uint32 { + return 4 +} + +func (h *textHeader) PayloadLength() uint32 { + return h.l +} + +func (h *textHeader) ReadFrom(r io.Reader) (int64, error) { + hdr := make([]byte, 4) + n, err := io.ReadFull(r, hdr) + if err != nil { + return int64(n), err + } + h.l = binary.BigEndian.Uint32(hdr) + return int64(n), nil +} + +func (h *textHeader) WriteTo(w io.Writer) (int64, error) { + buf := make([]byte, 4) + binary.BigEndian.PutUint32(buf, h.l) + n, err := w.Write(buf) + return int64(n), err +} diff --git a/pkg/networking/timers.go b/pkg/networking/timers.go new file mode 100644 index 000000000..5f2d2949f --- /dev/null +++ b/pkg/networking/timers.go @@ -0,0 +1,41 @@ +package networking + +import ( + "sync" + "time" +) + +const initialTimerInterval = time.Hour * 1e6 + +type timerPool struct { + p *sync.Pool +} + +func newTimerPool() *timerPool { + return &timerPool{ + p: &sync.Pool{ + New: func() any { + timer := time.NewTimer(initialTimerInterval) + timer.Stop() + return timer + }, + }, + } +} + +func (p *timerPool) Get() *time.Timer { + t, ok := p.p.Get().(*time.Timer) + if !ok { + panic("invalid type of item in TimerPool") + } + return t +} + +func (p *timerPool) Put(t *time.Timer) { + t.Stop() + select { + case <-t.C: + default: + } + p.p.Put(t) +} From 3d5e202e86767be8a2009167079a02e6e18ab950 Mon Sep 17 00:00:00 2001 From: Alexey Kiselev Date: Mon, 18 Nov 2024 11:13:02 +0400 Subject: [PATCH 03/30] Logger interface removed from networking package. Standard slog package is used instead. --- pkg/networking/configuration.go | 6 +-- pkg/networking/logger.go | 96 --------------------------------- pkg/networking/network.go | 2 + pkg/networking/session.go | 12 ++--- pkg/networking/session_test.go | 21 +++----- 5 files changed, 16 insertions(+), 121 deletions(-) delete mode 100644 pkg/networking/logger.go diff --git a/pkg/networking/configuration.go b/pkg/networking/configuration.go index 497c221b5..46fb53f8b 100644 --- a/pkg/networking/configuration.go +++ b/pkg/networking/configuration.go @@ -12,7 +12,7 @@ const ( // Config allows to set some parameters of the [Conn] or it's underlying connection. type Config struct { - logger Logger + logger *slog.Logger protocol Protocol handler Handler keepAlive bool @@ -25,7 +25,7 @@ type Config struct { // Other parameters are set to their default values. func NewConfig(p Protocol, h Handler) *Config { return &Config{ - logger: noopLogger{}, + logger: slog.Default(), protocol: p, handler: h, keepAlive: true, @@ -36,7 +36,7 @@ func NewConfig(p Protocol, h Handler) *Config { } // WithLogger sets the logger. -func (c *Config) WithLogger(logger Logger) *Config { +func (c *Config) WithLogger(logger *slog.Logger) *Config { c.logger = logger return c } diff --git a/pkg/networking/logger.go b/pkg/networking/logger.go deleted file mode 100644 index c518b27e0..000000000 --- a/pkg/networking/logger.go +++ /dev/null @@ -1,96 +0,0 @@ -package networking - -import ( - "context" -) - -const Namespace = "NET" - -type Logger interface { - // Debug logs a message at the debug level. - Debug(msg string, args ...any) - - // DebugContext logs a message at the debug level with access to the context's values - DebugContext(ctx context.Context, msg string, args ...any) - - // Info logs a message at the info level. - Info(msg string, args ...any) - - // InfoContext logs a message at the info level with access to the context's values - InfoContext(ctx context.Context, msg string, args ...any) - - // Warn logs a message at the warn level. - Warn(msg string, args ...any) - - // WarnContext logs a message at the warn level with access to the context's values - WarnContext(ctx context.Context, msg string, args ...any) - - // Error logs a message at the error level. - Error(msg string, args ...any) - - // ErrorContext logs a message at the error level with access to the context's values - ErrorContext(ctx context.Context, msg string, args ...any) -} - -type wrappingLogger struct { - logger Logger - attributes []any -} - -func (l *wrappingLogger) Debug(msg string, args ...any) { - args = append(args, l.attributes...) - l.logger.Debug(msg, args...) -} - -func (l *wrappingLogger) DebugContext(ctx context.Context, msg string, args ...any) { - args = append(args, l.attributes...) - l.logger.DebugContext(ctx, msg, args...) -} - -func (l *wrappingLogger) Info(msg string, args ...any) { - args = append(args, l.attributes...) - l.logger.Info(msg, args...) -} - -func (l *wrappingLogger) InfoContext(ctx context.Context, msg string, args ...any) { - args = append(args, l.attributes...) - l.logger.InfoContext(ctx, msg, args...) -} - -func (l *wrappingLogger) Warn(msg string, args ...any) { - args = append(args, l.attributes...) - l.logger.Warn(msg, args...) -} - -func (l *wrappingLogger) WarnContext(ctx context.Context, msg string, args ...any) { - args = append(args, l.attributes...) - l.logger.WarnContext(ctx, msg, args...) -} - -func (l *wrappingLogger) Error(msg string, args ...any) { - args = append(args, l.attributes...) - l.logger.Error(msg, args...) -} - -func (l *wrappingLogger) ErrorContext(ctx context.Context, msg string, args ...any) { - args = append(args, l.attributes...) - l.logger.ErrorContext(ctx, msg, args...) -} - -type noopLogger struct{} - -func (noopLogger) Debug(string, ...any) {} - -func (noopLogger) DebugContext(context.Context, string, ...any) {} - -func (noopLogger) Info(string, ...any) {} - -func (noopLogger) InfoContext(context.Context, string, ...any) {} - -func (noopLogger) Warn(string, ...any) {} - -func (noopLogger) WarnContext(context.Context, string, ...any) {} - -func (noopLogger) Error(string, ...any) {} - -func (noopLogger) ErrorContext(context.Context, string, ...any) {} diff --git a/pkg/networking/network.go b/pkg/networking/network.go index f0dabec3f..6715f24ed 100644 --- a/pkg/networking/network.go +++ b/pkg/networking/network.go @@ -7,6 +7,8 @@ import ( "io" ) +const Namespace = "NET" + var ( // ErrInvalidConfigurationNoProtocol is used when the configuration has no protocol. ErrInvalidConfigurationNoProtocol = errors.New("invalid configuration: empty protocol") diff --git a/pkg/networking/session.go b/pkg/networking/session.go index 97f8085a7..189a59ce6 100644 --- a/pkg/networking/session.go +++ b/pkg/networking/session.go @@ -23,7 +23,7 @@ type Session struct { cancel context.CancelFunc config *Config - logger Logger + logger *slog.Logger tp *timerPool conn io.ReadWriteCloser // conn is the underlying connection @@ -53,7 +53,6 @@ func newSession(ctx context.Context, config *Config, conn io.ReadWriteCloser, tp return nil, ErrEmptyTimerPool } sCtx, cancel := context.WithCancel(ctx) - //TODO: Properly initialize sendCh s := &Session{ g: execution.NewTaskGroup(suppressContextCancellationError), ctx: sCtx, @@ -62,7 +61,7 @@ func newSession(ctx context.Context, config *Config, conn io.ReadWriteCloser, tp tp: tp, conn: conn, bufRead: bufio.NewReader(conn), - sendCh: make(chan *sendPacket, 1), + sendCh: make(chan *sendPacket, 1), // TODO: Make the size of send channel configurable. } attributes := []any{ @@ -72,12 +71,9 @@ func newSession(ctx context.Context, config *Config, conn io.ReadWriteCloser, tp attributes = append(attributes, config.attributes...) if config.logger == nil { - s.logger = noopLogger{} + s.logger = slog.Default().With(attributes...) } else { - s.logger = &wrappingLogger{ - logger: config.logger, - attributes: attributes, - } + s.logger = config.logger.With(attributes...) } s.g.Run(s.receiveLoop) diff --git a/pkg/networking/session_test.go b/pkg/networking/session_test.go index 393a6ec90..51fbd8369 100644 --- a/pkg/networking/session_test.go +++ b/pkg/networking/session_test.go @@ -6,6 +6,7 @@ import ( "errors" "io" "log/slog" + "runtime" "sync" "testing" "time" @@ -113,34 +114,26 @@ func TestSessionTimeoutOnHandshake(t *testing.T) { require.NoError(t, err) mockProtocol.On("EmptyHandshake").Return(&textHandshake{}, nil) - serverHandler.On("OnClose", serverSession).Return() - - var wg sync.WaitGroup - wg.Add(1) - go func() { - // Close server. - err = serverSession.Close() - assert.NoError(t, err) - wg.Done() - }() + clientHandler.On("OnClose", clientSession).Return() // Lock pc, ok := clientConn.(*pipeConn) require.True(t, ok) pc.writeBlocker.Lock() - - clientHandler.On("OnClose", clientSession).Return() + runtime.Gosched() // Send handshake to server, but writing will block because the clientConn is locked. n, err := clientSession.Write([]byte("hello")) require.Error(t, err) assert.Equal(t, 0, n) - time.Sleep(2 * time.Second) // Let timeout occur. + runtime.Gosched() + + err = serverSession.Close() + assert.NoError(t, err) // Unlock "timeout" and close client. - wg.Wait() pc.writeBlocker.Unlock() err = clientSession.Close() assert.Error(t, err) From 60d0178058ae858e4ede8240cf2c31c784c80095 Mon Sep 17 00:00:00 2001 From: Alexey Kiselev Date: Mon, 25 Nov 2024 12:26:36 +0400 Subject: [PATCH 04/30] WIP. Simple connection replaced with NetClient. NetClient usage moved into Universal client. Handshake proto updated to compatibility with Handshake interface from networking package. --- itests/clients/grpc_client.go | 5 + itests/clients/net_client.go | 212 +++++++++++++++++++++++++++++ itests/clients/node_client.go | 40 +++++- itests/clients/universal_client.go | 9 +- itests/fixtures/base_fixtures.go | 4 +- itests/net/connection.go | 148 -------------------- itests/utilities/common.go | 7 +- pkg/networking/session.go | 10 +- pkg/p2p/conn/conn.go | 6 +- pkg/proto/microblock.go | 21 +-- pkg/proto/proto.go | 126 ++++++++--------- pkg/ride/math/math_test.go | 1 + pkg/util/common/util.go | 11 ++ 13 files changed, 364 insertions(+), 236 deletions(-) create mode 100644 itests/clients/net_client.go delete mode 100644 itests/net/connection.go diff --git a/itests/clients/grpc_client.go b/itests/clients/grpc_client.go index bb68da291..2cf1eadbb 100644 --- a/itests/clients/grpc_client.go +++ b/itests/clients/grpc_client.go @@ -94,6 +94,11 @@ func (c *GRPCClient) GetAssetsInfo(t *testing.T, id []byte) *g.AssetInfoResponse return assetInfo } +func (c *GRPCClient) Close(t testing.TB) { + err := c.conn.Close() + assert.NoError(t, err, "failed to close GRPC connection to %s node", c.impl.String()) +} + func (c *GRPCClient) getBalance(t *testing.T, req *g.BalancesRequest) *g.BalanceResponse { ctx, cancel := context.WithTimeout(context.Background(), c.timeout) defer cancel() diff --git a/itests/clients/net_client.go b/itests/clients/net_client.go new file mode 100644 index 000000000..e20524c41 --- /dev/null +++ b/itests/clients/net_client.go @@ -0,0 +1,212 @@ +package clients + +import ( + "bytes" + "context" + "encoding/base64" + "log/slog" + "net" + "sync" + "testing" + "time" + + "github.com/neilotoole/slogt" + "github.com/stretchr/testify/require" + + "github.com/wavesplatform/gowaves/itests/config" + "github.com/wavesplatform/gowaves/pkg/networking" + "github.com/wavesplatform/gowaves/pkg/proto" +) + +const ( + appName = "wavesL" + nonce = uint64(0) + networkTimeout = 3 * time.Second +) + +type NetClient struct { + ctx context.Context + t testing.TB + impl Implementation + n *networking.Network + c *networking.Config + s *networking.Session + + closedLock sync.Mutex + closed bool +} + +func NewNetClient( + ctx context.Context, t testing.TB, impl Implementation, port string, peers []proto.PeerInfo, +) *NetClient { + n := networking.NewNetwork() + p := newProtocol(nil) + h := newHandler(t, peers) + log := slogt.New(t) + conf := networking.NewConfig(p, h). + WithLogger(log). + WithWriteTimeout(networkTimeout). + WithSlogAttribute(slog.String("suite", t.Name())). + WithSlogAttribute(slog.String("impl", impl.String())) + + conn, err := net.Dial("tcp", config.DefaultIP+":"+port) + require.NoError(t, err, "failed to dial TCP to %s node", impl.String()) + + s, err := n.NewSession(ctx, conn, conf) + require.NoError(t, err, "failed to establish new session to %s node", impl.String()) + + cli := &NetClient{ctx: ctx, t: t, impl: impl, n: n, c: conf, s: s} + h.client = cli // Set client reference in handler. + return cli +} + +func (c *NetClient) SendHandshake() { + handshake := &proto.Handshake{ + AppName: appName, + Version: proto.ProtocolVersion(), + NodeName: "itest", + NodeNonce: nonce, + DeclaredAddr: proto.HandshakeTCPAddr{}, + Timestamp: proto.NewTimestampFromTime(time.Now()), + } + buf := bytes.NewBuffer(nil) + _, err := handshake.WriteTo(buf) + require.NoError(c.t, err, + "failed to marshal handshake to %s node at %q", c.impl.String(), c.s.RemoteAddr()) + _, err = c.s.Write(buf.Bytes()) + require.NoError(c.t, err, + "failed to send handshake to %s node at %q", c.impl.String(), c.s.RemoteAddr()) +} + +func (c *NetClient) SendMessage(m proto.Message) { + // TODO: Postpone message sending if the connection is reconnecting. + b, err := m.MarshalBinary() + require.NoError(c.t, err, "failed to marshal message to %s node at %q", c.impl.String(), c.s.RemoteAddr()) + _, err = c.s.Write(b) + require.NoError(c.t, err, "failed to send message to %s node at %q", c.impl.String(), c.s.RemoteAddr()) +} + +func (c *NetClient) Close() { + c.closedLock.Lock() + defer c.closedLock.Unlock() + if c.closed { + return + } + c.closed = true + _ = c.s.Close() +} + +func (c *NetClient) reconnect() { + // Check if the client was manually closed, in which case we don't want to reconnect. + c.closedLock.Lock() + defer c.closedLock.Unlock() + if c.closed { + return + } + c.t.Logf("Reconnecting to %q", c.s.RemoteAddr().String()) + conn, err := net.Dial("tcp", c.s.RemoteAddr().String()) + require.NoError(c.t, err, "failed to dial TCP to %s node", c.impl.String()) + + s, err := c.n.NewSession(c.ctx, conn, c.c) + require.NoError(c.t, err, "failed to re-establish the session to %s node", c.impl.String()) + c.s = s + + c.SendHandshake() +} + +type protocol struct { + dropLock sync.Mutex + drop map[proto.PeerMessageID]struct{} +} + +func newProtocol(drop []proto.PeerMessageID) *protocol { + m := make(map[proto.PeerMessageID]struct{}) + for _, id := range drop { + m[id] = struct{}{} + } + return &protocol{drop: m} +} + +func (p *protocol) EmptyHandshake() networking.Handshake { + return &proto.Handshake{} +} + +func (p *protocol) EmptyHeader() networking.Header { + return &proto.Header{} +} + +func (p *protocol) Ping() ([]byte, error) { + msg := &proto.GetPeersMessage{} + return msg.MarshalBinary() +} + +func (p *protocol) IsAcceptableHandshake(h networking.Handshake) bool { + hs, ok := h.(*proto.Handshake) + if !ok { + return false + } + // Reject nodes with incorrect network bytes, unsupported protocol versions, + // or a zero nonce (indicating a self-connection). + if hs.AppName != appName || hs.Version.Cmp(proto.ProtocolVersion()) < 0 || hs.NodeNonce == 0 { + return false + } + return true +} + +func (p *protocol) IsAcceptableMessage(h networking.Header) bool { + hdr, ok := h.(*proto.Header) + if !ok { + return false + } + p.dropLock.Lock() + defer p.dropLock.Unlock() + _, ok = p.drop[hdr.ContentID] + return !ok +} + +type handler struct { + peers []proto.PeerInfo + t testing.TB + client *NetClient +} + +func newHandler(t testing.TB, peers []proto.PeerInfo) *handler { + return &handler{t: t, peers: peers} +} + +func (h *handler) OnReceive(s *networking.Session, data []byte) { + msg, err := proto.UnmarshalMessage(data) + if err != nil { // Fail test on unmarshal error. + h.t.Logf("Failed to unmarshal message from bytes: %q", base64.StdEncoding.EncodeToString(data)) + h.t.FailNow() + return + } + switch msg.(type) { // Only reply with peers on GetPeersMessage. + case *proto.GetPeersMessage: + h.t.Logf("Received GetPeersMessage from %q", s.RemoteAddr()) + rpl := &proto.PeersMessage{Peers: h.peers} + bts, mErr := rpl.MarshalBinary() + if mErr != nil { // Fail test on marshal error. + h.t.Logf("Failed to marshal peers message: %v", mErr) + h.t.FailNow() + return + } + if _, wErr := s.Write(bts); wErr != nil { + h.t.Logf("Failed to send peers message: %v", wErr) + h.t.FailNow() + return + } + default: + } +} + +func (h *handler) OnHandshake(_ *networking.Session, _ networking.Handshake) { + h.t.Logf("Connection to %s node at %q was established", h.client.impl.String(), h.client.s.RemoteAddr()) +} + +func (h *handler) OnClose(s *networking.Session) { + h.t.Logf("Connection to %q was closed", s.RemoteAddr()) + if h.client != nil { + h.client.reconnect() + } +} diff --git a/itests/clients/node_client.go b/itests/clients/node_client.go index 7cf28b9c8..e3597a4e6 100644 --- a/itests/clients/node_client.go +++ b/itests/clients/node_client.go @@ -10,8 +10,10 @@ import ( "github.com/cenkalti/backoff/v4" "github.com/pkg/errors" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "golang.org/x/sync/errgroup" + "github.com/wavesplatform/gowaves/itests/config" d "github.com/wavesplatform/gowaves/itests/docker" "github.com/wavesplatform/gowaves/pkg/crypto" "github.com/wavesplatform/gowaves/pkg/proto" @@ -24,10 +26,19 @@ type NodesClients struct { ScalaClient *NodeUniversalClient } -func NewNodesClients(t *testing.T, goPorts, scalaPorts *d.PortConfig) *NodesClients { +func NewNodesClients(ctx context.Context, t *testing.T, goPorts, scalaPorts *d.PortConfig) *NodesClients { + sp, err := proto.NewPeerInfoFromString(config.DefaultIP + ":" + scalaPorts.BindPort) + require.NoError(t, err, "failed to create Scala peer info") + gp, err := proto.NewPeerInfoFromString(config.DefaultIP + ":" + goPorts.BindPort) + require.NoError(t, err, "failed to create Go peer info") + peers := []proto.PeerInfo{sp, gp} return &NodesClients{ - GoClient: NewNodeUniversalClient(t, NodeGo, goPorts.RESTAPIPort, goPorts.GRPCPort), - ScalaClient: NewNodeUniversalClient(t, NodeScala, scalaPorts.RESTAPIPort, scalaPorts.GRPCPort), + GoClient: NewNodeUniversalClient( + ctx, t, NodeGo, goPorts.RESTAPIPort, goPorts.GRPCPort, goPorts.BindPort, peers, + ), + ScalaClient: NewNodeUniversalClient( + ctx, t, NodeScala, scalaPorts.RESTAPIPort, scalaPorts.GRPCPort, scalaPorts.BindPort, peers, + ), } } @@ -273,6 +284,29 @@ func (c *NodesClients) SynchronizedWavesBalances( return r } +func (c *NodesClients) Handshake() { + c.GoClient.Connection.SendHandshake() + c.ScalaClient.Connection.SendHandshake() +} + +func (c *NodesClients) SendToNodes(t *testing.T, m proto.Message, scala bool) { + t.Logf("Sending message to Go node: %T", m) + c.GoClient.Connection.SendMessage(m) + t.Log("Message sent to Go node") + if scala { + t.Logf("Sending message to Scala node: %T", m) + c.ScalaClient.Connection.SendMessage(m) + t.Log("Message sent to Scala node") + } +} + +func (c *NodesClients) Close(t *testing.T) { + c.GoClient.GRPCClient.Close(t) + c.GoClient.Connection.Close() + c.ScalaClient.GRPCClient.Close(t) + c.ScalaClient.Connection.Close() +} + func (c *NodesClients) requestNodesAvailableBalances( ctx context.Context, address proto.WavesAddress, ) (addressedBalanceAtHeight, error) { diff --git a/itests/clients/universal_client.go b/itests/clients/universal_client.go index 9911e2015..32c20833a 100644 --- a/itests/clients/universal_client.go +++ b/itests/clients/universal_client.go @@ -1,19 +1,26 @@ package clients import ( + "context" "testing" + + "github.com/wavesplatform/gowaves/pkg/proto" ) type NodeUniversalClient struct { Implementation Implementation HTTPClient *HTTPClient GRPCClient *GRPCClient + Connection *NetClient } -func NewNodeUniversalClient(t *testing.T, impl Implementation, httpPort string, grpcPort string) *NodeUniversalClient { +func NewNodeUniversalClient( + ctx context.Context, t *testing.T, impl Implementation, httpPort, grpcPort, netPort string, peers []proto.PeerInfo, +) *NodeUniversalClient { return &NodeUniversalClient{ Implementation: impl, HTTPClient: NewHTTPClient(t, impl, httpPort), GRPCClient: NewGRPCClient(t, impl, grpcPort), + Connection: NewNetClient(ctx, t, impl, netPort, peers), } } diff --git a/itests/fixtures/base_fixtures.go b/itests/fixtures/base_fixtures.go index 23f359066..eb8f9a911 100644 --- a/itests/fixtures/base_fixtures.go +++ b/itests/fixtures/base_fixtures.go @@ -49,7 +49,8 @@ func (suite *BaseSuite) BaseSetup(options ...config.BlockchainOption) { suite.Require().NoError(ssErr, "couldn't start Scala node container") } - suite.Clients = clients.NewNodesClients(suite.T(), docker.GoNode().Ports(), docker.ScalaNode().Ports()) + suite.Clients = clients.NewNodesClients(suite.MainCtx, suite.T(), docker.GoNode().Ports(), docker.ScalaNode().Ports()) + suite.Clients.Handshake() } func (suite *BaseSuite) SetupSuite() { @@ -58,6 +59,7 @@ func (suite *BaseSuite) SetupSuite() { func (suite *BaseSuite) TearDownSuite() { suite.Clients.WaitForStateHashEquality(suite.T()) + suite.Clients.Close(suite.T()) suite.Docker.Finish(suite.Cancel) } diff --git a/itests/net/connection.go b/itests/net/connection.go deleted file mode 100644 index 7fafe65f1..000000000 --- a/itests/net/connection.go +++ /dev/null @@ -1,148 +0,0 @@ -package net - -import ( - "bufio" - stderrs "errors" - "net" - "testing" - "time" - - "github.com/cenkalti/backoff/v4" - "github.com/pkg/errors" - "github.com/stretchr/testify/assert" - - "github.com/wavesplatform/gowaves/itests/config" - d "github.com/wavesplatform/gowaves/itests/docker" - "github.com/wavesplatform/gowaves/pkg/proto" -) - -type OutgoingPeer struct { - conn net.Conn -} - -func NewConnection(declAddr proto.TCPAddr, address string, ver proto.Version, wavesNetwork string) (op *OutgoingPeer, err error) { - c, err := net.Dial("tcp", address) - if err != nil { - return nil, errors.Wrapf(err, "failed to connect to %s", address) - } - defer func() { - if err != nil { - if closeErr := c.Close(); closeErr != nil { - err = errors.Wrap(err, closeErr.Error()) - } - } - }() - handshake := proto.Handshake{ - AppName: wavesNetwork, - Version: ver, - NodeName: "itest", - NodeNonce: 0x0, - DeclaredAddr: proto.HandshakeTCPAddr(declAddr), - Timestamp: proto.NewTimestampFromTime(time.Now()), - } - - _, err = handshake.WriteTo(c) - if err != nil { - return nil, errors.Wrapf(err, "failed to send handshake to %s", address) - } - - _, err = handshake.ReadFrom(bufio.NewReader(c)) - if err != nil { - return nil, errors.Wrapf(err, "failed to read handshake from %s", address) - } - - return &OutgoingPeer{conn: c}, nil -} - -func (a *OutgoingPeer) SendMessage(m proto.Message) error { - b, err := m.MarshalBinary() - if err != nil { - return err - } - - _, err = a.conn.Write(b) - if err != nil { - return errors.Wrapf(err, "failed to send message") - } - return nil -} - -func (a *OutgoingPeer) Close() error { - return a.conn.Close() -} - -type NodeConnections struct { - scalaCon *OutgoingPeer - goCon *OutgoingPeer -} - -func NewNodeConnections(goPorts, scalaPorts *d.PortConfig) (NodeConnections, error) { - var connections NodeConnections - err := retry(1*time.Second, func() error { - var err error - connections, err = establishConnections(goPorts, scalaPorts) - return err - }) - return connections, err -} - -func establishConnections(goPorts, scalaPorts *d.PortConfig) (NodeConnections, error) { - goCon, err := NewConnection( - proto.TCPAddr{}, - config.DefaultIP+":"+goPorts.BindPort, - proto.ProtocolVersion(), "wavesL", - ) - if err != nil { - return NodeConnections{}, errors.Wrap(err, "failed to create connection to go node") - } - scalaCon, err := NewConnection( - proto.TCPAddr{}, - config.DefaultIP+":"+scalaPorts.BindPort, - proto.ProtocolVersion(), "wavesL", - ) - if err != nil { - if closeErr := goCon.Close(); closeErr != nil { - return NodeConnections{}, errors.Wrap(stderrs.Join(closeErr, err), - "failed to create connection to scala node and close go node connection") - } - return NodeConnections{}, errors.Wrap(err, "failed to create connection to scala node") - } - return NodeConnections{scalaCon: scalaCon, goCon: goCon}, nil -} - -func retry(timeout time.Duration, f func() error) error { - bo := backoff.NewExponentialBackOff() - bo.InitialInterval = 100 * time.Millisecond - bo.MaxInterval = 500 * time.Millisecond - bo.MaxElapsedTime = timeout - if err := backoff.Retry(f, bo); err != nil { - if bo.NextBackOff() == backoff.Stop { - return errors.Wrap(err, "reached retry deadline") - } - return err - } - return nil -} - -func (c *NodeConnections) SendToNodes(t *testing.T, m proto.Message, scala bool) { - t.Logf("Sending message to go node: %T", m) - err := c.goCon.SendMessage(m) - assert.NoError(t, err, "failed to send TransactionMessage to go node") - t.Log("Message sent to go node") - if scala { - t.Logf("Sending message to scala node: %T", m) - err = c.scalaCon.SendMessage(m) - assert.NoError(t, err, "failed to send TransactionMessage to scala node") - t.Log("Message sent to scala node") - } -} - -func (c *NodeConnections) Close(t *testing.T) { - t.Log("Closing connections") - err := c.goCon.Close() - assert.NoError(t, err, "failed to close go node connection") - - err = c.scalaCon.Close() - assert.NoError(t, err, "failed to close scala node connection") - t.Log("Connections closed") -} diff --git a/itests/utilities/common.go b/itests/utilities/common.go index 9c7a90a54..960823072 100644 --- a/itests/utilities/common.go +++ b/itests/utilities/common.go @@ -23,7 +23,6 @@ import ( "github.com/wavesplatform/gowaves/itests/config" f "github.com/wavesplatform/gowaves/itests/fixtures" - "github.com/wavesplatform/gowaves/itests/net" "github.com/wavesplatform/gowaves/pkg/client" "github.com/wavesplatform/gowaves/pkg/crypto" g "github.com/wavesplatform/gowaves/pkg/grpc/generated/waves/node/grpc" @@ -671,11 +670,7 @@ func SendAndWaitTransaction(suite *f.BaseSuite, tx proto.Transaction, scheme pro } scala := !waitForTx - connections, err := net.NewNodeConnections(suite.Docker.GoNode().Ports(), suite.Docker.ScalaNode().Ports()) - suite.Require().NoError(err, "failed to create new node connections") - defer connections.Close(suite.T()) - - connections.SendToNodes(suite.T(), txMsg, scala) + suite.Clients.SendToNodes(suite.T(), txMsg, scala) suite.T().Log("Tx msg was successfully send to nodes") suite.T().Log("Waiting for Tx appears in Blockchain") diff --git a/pkg/networking/session.go b/pkg/networking/session.go index 189a59ce6..cfbed99ab 100644 --- a/pkg/networking/session.go +++ b/pkg/networking/session.go @@ -299,7 +299,8 @@ func (s *Session) readMessage(hdr Header) error { // Read the header if _, err := hdr.ReadFrom(s.bufRead); err != nil { if errors.Is(err, io.EOF) || strings.Contains(err.Error(), "closed") || - strings.Contains(err.Error(), "reset by peer") { + strings.Contains(err.Error(), "reset by peer") || + strings.Contains(err.Error(), "broken pipe") { // In Docker network built on top of pipe, we get this error on close. return ErrConnectionClosedOnRead } s.logger.Error("Failed to read header", "error", err) @@ -322,6 +323,7 @@ func (s *Session) readMessage(hdr Header) error { func (s *Session) readMessagePayload(hdr Header, conn io.Reader) error { // Wrap in a limited reader + s.logger.Debug("Reading message payload", "len", hdr.PayloadLength()) conn = &io.LimitedReader{R: conn, N: int64(hdr.PayloadLength())} // Copy into buffer @@ -337,15 +339,19 @@ func (s *Session) readMessagePayload(hdr Header, conn io.Reader) error { s.logger.Error("Failed to write header to receiving buffer", "error", err) return err } - _, err = io.Copy(s.receiveBuffer, conn) + n, err := io.Copy(s.receiveBuffer, conn) if err != nil { s.logger.Error("Failed to copy payload to receiving buffer", "error", err) return err } + s.logger.Debug("Message payload successfully read", "len", n) + // We lock the buffer from modification on the time of invocation of OnReceive handler. // The slice of bytes passed into the handler is only valid for the duration of the handler invocation. // So inside the handler better deserialize message or make a copy of the bytes. + s.logger.Debug("Invoking OnReceive handler", "message", base64.StdEncoding.EncodeToString(s.receiveBuffer.Bytes())) s.config.handler.OnReceive(s, s.receiveBuffer.Bytes()) // Invoke OnReceive handler. + s.receiveBuffer.Reset() // Reset the buffer for the next message. return nil } diff --git a/pkg/p2p/conn/conn.go b/pkg/p2p/conn/conn.go index b477bbe55..d2de99d95 100644 --- a/pkg/p2p/conn/conn.go +++ b/pkg/p2p/conn/conn.go @@ -124,12 +124,12 @@ func receiveFromRemote(conn deadlineReader, fromRemoteCh chan *bytebufferpool.By return errors.Wrap(err, "failed to read header") } // received too big message, probably it's an error - if l := int(header.HeaderLength() + header.PayloadLength); l > maxMessageSize { + if l := int(header.HeaderLength() + header.PayloadLength()); l > maxMessageSize { return errors.Errorf("received too long message, size=%d > max=%d", l, maxMessageSize) } if skip(header) { - if _, err := io.CopyN(io.Discard, reader, int64(header.PayloadLength)); err != nil { + if _, err := io.CopyN(io.Discard, reader, int64(header.PayloadLength())); err != nil { return errors.Wrap(err, "failed to skip payload") } continue @@ -142,7 +142,7 @@ func receiveFromRemote(conn deadlineReader, fromRemoteCh chan *bytebufferpool.By return errors.Wrap(err, "failed to write header into buff") } // then read all message to remaining buffer - if _, err := io.CopyN(b, reader, int64(header.PayloadLength)); err != nil { + if _, err := io.CopyN(b, reader, int64(header.PayloadLength())); err != nil { bytebufferpool.Put(b) return errors.Wrap(err, "failed to read payload into buffer") } diff --git a/pkg/proto/microblock.go b/pkg/proto/microblock.go index a02f6657c..ceb9c6048 100644 --- a/pkg/proto/microblock.go +++ b/pkg/proto/microblock.go @@ -11,6 +11,7 @@ import ( g "github.com/wavesplatform/gowaves/pkg/grpc/generated/waves" "github.com/wavesplatform/gowaves/pkg/libs/deserializer" "github.com/wavesplatform/gowaves/pkg/libs/serializer" + "github.com/wavesplatform/gowaves/pkg/util/common" ) const ( @@ -278,7 +279,7 @@ func (a *MicroBlockMessage) UnmarshalBinary(data []byte) error { if len(data) < crypto.SignatureSize*2+1 { return errors.New("invalid micro block size") } - b := make([]byte, len(data[:h.PayloadLength])) + b := make([]byte, len(data[:h.payloadLength])) copy(b, data) a.Body = b @@ -311,7 +312,7 @@ func (a *MicroBlockInvMessage) WriteTo(w io.Writer) (n int64, err error) { h.Length = maxHeaderLength + uint32(len(a.Body)) - 4 h.Magic = headerMagic h.ContentID = ContentIDInvMicroblock - h.PayloadLength = uint32(len(a.Body)) + h.payloadLength = common.SafeIntToUint32(len(a.Body)) dig, err := crypto.FastHash(a.Body) if err != nil { return 0, err @@ -351,10 +352,10 @@ func (a *MicroBlockRequestMessage) ReadFrom(_ io.Reader) (n int64, err error) { func (a *MicroBlockRequestMessage) WriteTo(w io.Writer) (int64, error) { var h Header - h.Length = maxHeaderLength + uint32(len(a.TotalBlockSig)) - 4 + h.Length = maxHeaderLength + common.SafeIntToUint32(len(a.TotalBlockSig)) - headerChecksumLen h.Magic = headerMagic h.ContentID = ContentIDMicroblockRequest - h.PayloadLength = uint32(len(a.TotalBlockSig)) + h.payloadLength = common.SafeIntToUint32(len(a.TotalBlockSig)) dig, err := crypto.FastHash(a.TotalBlockSig) if err != nil { return 0, err @@ -393,7 +394,7 @@ func (a *MicroBlockRequestMessage) UnmarshalBinary(data []byte) error { return errors.Errorf("wrong ContentID in Header: %x", h.ContentID) } data = data[17:] - body := make([]byte, h.PayloadLength) + body := make([]byte, h.payloadLength) copy(body, data) a.TotalBlockSig = body return nil @@ -517,8 +518,8 @@ func (a *MicroBlockInvMessage) UnmarshalBinary(data []byte) error { return errors.Errorf("wrong ContentID in Header: %x", h.ContentID) } data = data[17:] - body := make([]byte, h.PayloadLength) - copy(body, data[:h.PayloadLength]) + body := make([]byte, h.payloadLength) + copy(body, data[:h.payloadLength]) a.Body = body return nil } @@ -563,15 +564,15 @@ func (a *PBMicroBlockMessage) UnmarshalBinary(data []byte) error { if h.ContentID != ContentIDPBMicroBlock { return errors.Errorf("wrong ContentID in Header: %x", h.ContentID) } - if h.PayloadLength < crypto.DigestSize { + if h.payloadLength < crypto.DigestSize { return errors.New("PBMicroBlockMessage UnmarshalBinary: invalid data size") } data = data[17:] - if uint32(len(data)) < h.PayloadLength { + if common.SafeIntToUint32(len(data)) < h.payloadLength { return errors.New("invalid data size") } - mbBytes := data[:h.PayloadLength] + mbBytes := data[:h.payloadLength] a.MicroBlockBytes = make([]byte, len(mbBytes)) copy(a.MicroBlockBytes, mbBytes) return nil diff --git a/pkg/proto/proto.go b/pkg/proto/proto.go index 8a917301f..28e8f82af 100644 --- a/pkg/proto/proto.go +++ b/pkg/proto/proto.go @@ -15,6 +15,7 @@ import ( "github.com/wavesplatform/gowaves/pkg/crypto" "github.com/wavesplatform/gowaves/pkg/util/collect_writes" + "github.com/wavesplatform/gowaves/pkg/util/common" ) const ( @@ -73,7 +74,7 @@ type Header struct { Length uint32 Magic uint32 ContentID PeerMessageID - PayloadLength uint32 + payloadLength uint32 PayloadChecksum [headerChecksumLen]byte } @@ -96,7 +97,7 @@ func (h *Header) WriteTo(w io.Writer) (int64, error) { } func (h *Header) HeaderLength() uint32 { - if h.PayloadLength > 0 { + if h.payloadLength > 0 { return headerSizeWithPayload } return headerSizeWithoutPayload @@ -132,8 +133,8 @@ func (h *Header) UnmarshalBinary(data []byte) error { return fmt.Errorf("received wrong magic: want %x, have %x", headerMagic, h.Magic) } h.ContentID = PeerMessageID(data[HeaderContentIDPosition]) - h.PayloadLength = binary.BigEndian.Uint32(data[9:headerSizeWithoutPayload]) - if h.PayloadLength > 0 { + h.payloadLength = binary.BigEndian.Uint32(data[9:headerSizeWithoutPayload]) + if h.payloadLength > 0 { if uint32(len(data)) < headerSizeWithPayload { return errors.New("Header UnmarshalBinary: invalid data size") } @@ -150,8 +151,8 @@ func (h *Header) Copy(data []byte) (int, error) { binary.BigEndian.PutUint32(data[0:4], h.Length) binary.BigEndian.PutUint32(data[4:8], headerMagic) data[HeaderContentIDPosition] = byte(h.ContentID) - binary.BigEndian.PutUint32(data[9:headerSizeWithoutPayload], h.PayloadLength) - if h.PayloadLength > 0 { + binary.BigEndian.PutUint32(data[9:headerSizeWithoutPayload], h.payloadLength) + if h.payloadLength > 0 { if len(data) < headerSizeWithPayload { return 0, errors.New("Header Copy: invalid data size") } @@ -161,6 +162,10 @@ func (h *Header) Copy(data []byte) (int, error) { return headerSizeWithoutPayload, nil } +func (h *Header) PayloadLength() uint32 { + return h.payloadLength +} + // Version represents the version of the protocol type Version struct { _ struct{} // this field disallows raw struct initialization @@ -498,10 +503,6 @@ func (a HandshakeTCPAddr) Network() string { return "tcp" } -func ParseHandshakeTCPAddr(s string) HandshakeTCPAddr { - return HandshakeTCPAddr(NewTCPAddrFromString(s)) -} - type U8String struct { S string } @@ -655,7 +656,7 @@ func (m *GetPeersMessage) MarshalBinary() ([]byte, error) { h.Length = maxHeaderLength - 8 h.Magic = headerMagic h.ContentID = ContentIDGetPeers - h.PayloadLength = 0 + h.payloadLength = 0 return h.MarshalBinary() } @@ -671,7 +672,7 @@ func (m *GetPeersMessage) UnmarshalBinary(b []byte) error { if header.ContentID != ContentIDGetPeers { return fmt.Errorf("getpeers message ContentID is unexpected: want %x have %x", ContentIDGetPeers, header.ContentID) } - if header.PayloadLength != 0 { + if header.payloadLength != 0 { return fmt.Errorf("getpeers message length is not zero") } @@ -909,10 +910,10 @@ func (m *PeersMessage) WriteTo(w io.Writer) (int64, error) { return n, err } - h.Length = maxHeaderLength + uint32(len(buf.Bytes())) - 4 + h.Length = maxHeaderLength + common.SafeIntToUint32(len(buf.Bytes())) - headerChecksumLen h.Magic = headerMagic h.ContentID = ContentIDPeers - h.PayloadLength = uint32(len(buf.Bytes())) + h.payloadLength = common.SafeIntToUint32(len(buf.Bytes())) dig, err := crypto.FastHash(buf.Bytes()) if err != nil { return 0, err @@ -1050,7 +1051,7 @@ func (m *GetSignaturesMessage) MarshalBinary() ([]byte, error) { h.Length = maxHeaderLength + uint32(len(body)) - 4 h.Magic = headerMagic h.ContentID = ContentIDGetSignatures - h.PayloadLength = uint32(len(body)) + h.payloadLength = common.SafeIntToUint32(len(body)) dig, err := crypto.FastHash(body) if err != nil { return nil, err @@ -1132,16 +1133,16 @@ type SignaturesMessage struct { // MarshalBinary encodes SignaturesMessage to binary form func (m *SignaturesMessage) MarshalBinary() ([]byte, error) { body := make([]byte, 4, 4+len(m.Signatures)) - binary.BigEndian.PutUint32(body[0:4], uint32(len(m.Signatures))) + binary.BigEndian.PutUint32(body[0:4], common.SafeIntToUint32(len(m.Signatures))) for _, b := range m.Signatures { body = append(body, b[:]...) } var h Header - h.Length = maxHeaderLength + uint32(len(body)) - 4 + h.Length = maxHeaderLength + common.SafeIntToUint32(len(body)) - headerChecksumLen h.Magic = headerMagic h.ContentID = ContentIDSignatures - h.PayloadLength = uint32(len(body)) + h.payloadLength = common.SafeIntToUint32(len(body)) dig, err := crypto.FastHash(body) if err != nil { return nil, err @@ -1225,10 +1226,10 @@ func (m *GetBlockMessage) MarshalBinary() ([]byte, error) { body := m.BlockID.Bytes() var h Header - h.Length = maxHeaderLength + uint32(len(body)) - 4 + h.Length = maxHeaderLength + common.SafeIntToUint32(len(body)) - headerChecksumLen h.Magic = headerMagic h.ContentID = ContentIDGetBlock - h.PayloadLength = uint32(len(body)) + h.payloadLength = common.SafeIntToUint32(len(body)) dig, err := crypto.FastHash(body) if err != nil { return nil, err @@ -1269,10 +1270,11 @@ func parsePacket(data []byte, ContentID PeerMessageID, name string, f func(paylo if h.ContentID != ContentID { return fmt.Errorf("%s: wrong ContentID in Header: %x", name, h.ContentID) } - if len(data) < int(17+h.PayloadLength) { - return fmt.Errorf("%s: expected data at least %d, found %d", name, 17+h.PayloadLength, len(data)) + if len(data) < int(headerSizeWithPayload+h.payloadLength) { + return fmt.Errorf("%s: expected data at least %d, found %d", + name, headerSizeWithPayload+h.payloadLength, len(data)) } - err := f(data[17 : 17+h.PayloadLength]) + err := f(data[headerSizeWithPayload : headerSizeWithPayload+h.payloadLength]) if err != nil { return errors.Wrapf(err, "%s payload error", name) } @@ -1320,10 +1322,10 @@ type BlockMessage struct { // MarshalBinary encodes BlockMessage to binary form func (m *BlockMessage) MarshalBinary() ([]byte, error) { var h Header - h.Length = maxHeaderLength + uint32(len(m.BlockBytes)) - 4 + h.Length = maxHeaderLength + common.SafeIntToUint32(len(m.BlockBytes)) - headerChecksumLen h.Magic = headerMagic h.ContentID = ContentIDBlock - h.PayloadLength = uint32(len(m.BlockBytes)) + h.payloadLength = common.SafeIntToUint32(len(m.BlockBytes)) dig, err := crypto.FastHash(m.BlockBytes) if err != nil { return nil, err @@ -1340,10 +1342,10 @@ func (m *BlockMessage) MarshalBinary() ([]byte, error) { func MakeHeader(contentID PeerMessageID, payload []byte) (Header, error) { var h Header - h.Length = maxHeaderLength + uint32(len(payload)) - 4 + h.Length = maxHeaderLength + common.SafeIntToUint32(len(payload)) - headerChecksumLen h.Magic = headerMagic h.ContentID = contentID - h.PayloadLength = uint32(len(payload)) + h.payloadLength = common.SafeIntToUint32(len(payload)) dig, err := crypto.FastHash(payload) if err != nil { return Header{}, err @@ -1365,11 +1367,11 @@ func (m *BlockMessage) UnmarshalBinary(data []byte) error { return fmt.Errorf("wrong ContentID in Header: %x", h.ContentID) } - if uint32(len(data)) < 17+h.PayloadLength { + if common.SafeIntToUint32(len(data)) < 17+h.payloadLength { return errors.New("BlockMessage UnmarshalBinary: invalid data size") } - m.BlockBytes = make([]byte, h.PayloadLength) - copy(m.BlockBytes, data[17:17+h.PayloadLength]) + m.BlockBytes = make([]byte, h.payloadLength) + copy(m.BlockBytes, data[17:17+h.payloadLength]) return nil } @@ -1403,10 +1405,10 @@ type ScoreMessage struct { // MarshalBinary encodes ScoreMessage to binary form func (m *ScoreMessage) MarshalBinary() ([]byte, error) { var h Header - h.Length = maxHeaderLength + uint32(len(m.Score)) - 4 + h.Length = maxHeaderLength + common.SafeIntToUint32(len(m.Score)) - headerChecksumLen h.Magic = headerMagic h.ContentID = ContentIDScore - h.PayloadLength = uint32(len(m.Score)) + h.payloadLength = common.SafeIntToUint32(len(m.Score)) dig, err := crypto.FastHash(m.Score) if err != nil { return nil, err @@ -1434,11 +1436,11 @@ func (m *ScoreMessage) UnmarshalBinary(data []byte) error { return fmt.Errorf("wrong ContentID in Header: %x", h.ContentID) } - if uint32(len(data)) < 17+h.PayloadLength { + if common.SafeIntToUint32(len(data)) < 17+h.payloadLength { return errors.New("invalid data size") } - m.Score = make([]byte, h.PayloadLength) - copy(m.Score, data[17:17+h.PayloadLength]) + m.Score = make([]byte, h.payloadLength) + copy(m.Score, data[17:17+h.payloadLength]) return nil } @@ -1470,10 +1472,10 @@ type TransactionMessage struct { // MarshalBinary encodes TransactionMessage to binary form func (m *TransactionMessage) MarshalBinary() ([]byte, error) { var h Header - h.Length = maxHeaderLength + uint32(len(m.Transaction)) - 4 + h.Length = maxHeaderLength + common.SafeIntToUint32(len(m.Transaction)) - headerChecksumLen h.Magic = headerMagic h.ContentID = ContentIDTransaction - h.PayloadLength = uint32(len(m.Transaction)) + h.payloadLength = common.SafeIntToUint32(len(m.Transaction)) dig, err := crypto.FastHash(m.Transaction) if err != nil { return nil, err @@ -1498,11 +1500,11 @@ func (m *TransactionMessage) UnmarshalBinary(data []byte) error { return fmt.Errorf("wrong ContentID in Header: %x", h.ContentID) } // TODO check max length - if uint32(len(data)) < maxHeaderLength+h.PayloadLength { + if common.SafeIntToUint32(len(data)) < maxHeaderLength+h.payloadLength { return errors.New("invalid data size") } - m.Transaction = make([]byte, h.PayloadLength) - copy(m.Transaction, data[maxHeaderLength:maxHeaderLength+h.PayloadLength]) + m.Transaction = make([]byte, h.payloadLength) + copy(m.Transaction, data[maxHeaderLength:maxHeaderLength+h.payloadLength]) dig, err := crypto.FastHash(m.Transaction) if err != nil { return err @@ -1558,10 +1560,10 @@ func (m *CheckPointMessage) MarshalBinary() ([]byte, error) { } var h Header - h.Length = maxHeaderLength + uint32(len(body)) - 4 + h.Length = maxHeaderLength + common.SafeIntToUint32(len(body)) - headerChecksumLen h.Magic = headerMagic h.ContentID = ContentIDCheckpoint - h.PayloadLength = uint32(len(body)) + h.payloadLength = common.SafeIntToUint32(len(body)) dig, err := crypto.FastHash(body) if err != nil { return nil, err @@ -1642,10 +1644,10 @@ type PBBlockMessage struct { // MarshalBinary encodes PBBlockMessage to binary form func (m *PBBlockMessage) MarshalBinary() ([]byte, error) { var h Header - h.Length = maxHeaderLength + uint32(len(m.PBBlockBytes)) - 4 + h.Length = maxHeaderLength + common.SafeIntToUint32(len(m.PBBlockBytes)) - headerChecksumLen h.Magic = headerMagic h.ContentID = ContentIDPBBlock - h.PayloadLength = uint32(len(m.PBBlockBytes)) + h.payloadLength = common.SafeIntToUint32(len(m.PBBlockBytes)) dig, err := crypto.FastHash(m.PBBlockBytes) if err != nil { return nil, err @@ -1673,11 +1675,11 @@ func (m *PBBlockMessage) UnmarshalBinary(data []byte) error { return fmt.Errorf("wrong ContentID in Header: %x", h.ContentID) } - m.PBBlockBytes = make([]byte, h.PayloadLength) - if uint32(len(data)) < 17+h.PayloadLength { + m.PBBlockBytes = make([]byte, h.payloadLength) + if common.SafeIntToUint32(len(data)) < 17+h.payloadLength { return errors.New("PBBlockMessage UnmarshalBinary: invalid data size") } - copy(m.PBBlockBytes, data[17:17+h.PayloadLength]) + copy(m.PBBlockBytes, data[17:17+h.payloadLength]) return nil } @@ -1711,10 +1713,10 @@ type PBTransactionMessage struct { // MarshalBinary encodes PBTransactionMessage to binary form func (m *PBTransactionMessage) MarshalBinary() ([]byte, error) { var h Header - h.Length = maxHeaderLength + uint32(len(m.Transaction)) - 4 + h.Length = maxHeaderLength + common.SafeIntToUint32(len(m.Transaction)) - headerChecksumLen h.Magic = headerMagic h.ContentID = ContentIDPBTransaction - h.PayloadLength = uint32(len(m.Transaction)) + h.payloadLength = common.SafeIntToUint32(len(m.Transaction)) dig, err := crypto.FastHash(m.Transaction) if err != nil { return nil, err @@ -1739,11 +1741,11 @@ func (m *PBTransactionMessage) UnmarshalBinary(data []byte) error { return fmt.Errorf("wrong ContentID in Header: %x", h.ContentID) } // TODO check max length - m.Transaction = make([]byte, h.PayloadLength) - if uint32(len(data)) < maxHeaderLength+h.PayloadLength { + m.Transaction = make([]byte, h.payloadLength) + if common.SafeIntToUint32(len(data)) < maxHeaderLength+h.payloadLength { return errors.New("PBTransactionMessage UnmarshalBinary: invalid data size") } - copy(m.Transaction, data[maxHeaderLength:maxHeaderLength+h.PayloadLength]) + copy(m.Transaction, data[maxHeaderLength:maxHeaderLength+h.payloadLength]) dig, err := crypto.FastHash(m.Transaction) if err != nil { return err @@ -1851,10 +1853,10 @@ func (m *GetBlockIdsMessage) MarshalBinary() ([]byte, error) { } var h Header - h.Length = maxHeaderLength + uint32(len(body)) - 4 + h.Length = maxHeaderLength + common.SafeIntToUint32(len(body)) - headerChecksumLen h.Magic = headerMagic h.ContentID = ContentIDGetBlockIDs - h.PayloadLength = uint32(len(body)) + h.payloadLength = common.SafeIntToUint32(len(body)) dig, err := crypto.FastHash(body) if err != nil { return nil, err @@ -1947,10 +1949,10 @@ func (m *BlockIdsMessage) MarshalBinary() ([]byte, error) { } var h Header - h.Length = maxHeaderLength + uint32(len(body)) - 4 + h.Length = maxHeaderLength + common.SafeIntToUint32(len(body)) - headerChecksumLen h.Magic = headerMagic h.ContentID = ContentIDBlockIDs - h.PayloadLength = uint32(len(body)) + h.payloadLength = common.SafeIntToUint32(len(body)) dig, err := crypto.FastHash(body) if err != nil { return nil, err @@ -2063,10 +2065,10 @@ type MiningLimits struct { func buildHeader(body []byte, messID PeerMessageID) (Header, error) { var h Header - h.Length = maxHeaderLength + uint32(len(body)) - headerChecksumLen + h.Length = maxHeaderLength + common.SafeIntToUint32(len(body)) - headerChecksumLen h.Magic = headerMagic h.ContentID = messID - h.PayloadLength = uint32(len(body)) + h.payloadLength = common.SafeIntToUint32(len(body)) dig, err := crypto.FastHash(body) if err != nil { return Header{}, err @@ -2161,8 +2163,8 @@ func (m *BlockSnapshotMessage) UnmarshalBinary(data []byte) error { if h.ContentID != ContentIDBlockSnapshot { return fmt.Errorf("wrong ContentID in Header: %x", h.ContentID) } - m.Bytes = make([]byte, h.PayloadLength) - copy(m.Bytes, data[maxHeaderLength:maxHeaderLength+h.PayloadLength]) + m.Bytes = make([]byte, h.payloadLength) + copy(m.Bytes, data[maxHeaderLength:maxHeaderLength+h.payloadLength]) return nil } @@ -2220,8 +2222,8 @@ func (m *MicroBlockSnapshotMessage) UnmarshalBinary(data []byte) error { if h.ContentID != ContentIDMicroBlockSnapshot { return fmt.Errorf("wrong ContentID in Header: %x", h.ContentID) } - m.Bytes = make([]byte, h.PayloadLength) - copy(m.Bytes, data[maxHeaderLength:maxHeaderLength+h.PayloadLength]) + m.Bytes = make([]byte, h.payloadLength) + copy(m.Bytes, data[maxHeaderLength:maxHeaderLength+h.payloadLength]) return nil } diff --git a/pkg/ride/math/math_test.go b/pkg/ride/math/math_test.go index e85da6d19..c237b0c66 100644 --- a/pkg/ride/math/math_test.go +++ b/pkg/ride/math/math_test.go @@ -21,6 +21,7 @@ func TestFraction(t *testing.T) { }{ {-6, 6301369, 100, false, -378082}, {6, 6301369, 100, false, 378082}, + {4445280, 1, 1440, false, 3087}, {6, 6301369, 0, true, 0}, } { r, err := Fraction(tc.value, tc.numerator, tc.denominator) diff --git a/pkg/util/common/util.go b/pkg/util/common/util.go index 0aff46164..b92f0fa3a 100644 --- a/pkg/util/common/util.go +++ b/pkg/util/common/util.go @@ -5,6 +5,7 @@ import ( "encoding/base64" "encoding/hex" "fmt" + "math" "math/big" "os/user" "path/filepath" @@ -244,3 +245,13 @@ func padBytes(p byte, bytes []byte) []byte { copy(r[1:], bytes) return r } + +func SafeIntToUint32(v int) uint32 { + if v < 0 { + panic("negative value") + } + if v > math.MaxUint32 { + panic("value is too big") + } + return uint32(v) +} From 948cc56e764875e610dcc8bfa1cca3cba50806ec Mon Sep 17 00:00:00 2001 From: Alexey Kiselev Date: Mon, 25 Nov 2024 16:58:01 +0400 Subject: [PATCH 05/30] Fixed NetClient closing issue. Configuration option to set KeepAliveInterval added to networking.Config. --- itests/clients/net_client.go | 34 +++++++++++++++++++++------------ pkg/networking/configuration.go | 5 +++++ 2 files changed, 27 insertions(+), 12 deletions(-) diff --git a/itests/clients/net_client.go b/itests/clients/net_client.go index e20524c41..dc7e9213a 100644 --- a/itests/clients/net_client.go +++ b/itests/clients/net_client.go @@ -22,6 +22,7 @@ const ( appName = "wavesL" nonce = uint64(0) networkTimeout = 3 * time.Second + pingInterval = 5 * time.Second ) type NetClient struct { @@ -32,8 +33,10 @@ type NetClient struct { c *networking.Config s *networking.Session - closedLock sync.Mutex - closed bool + closingLock sync.Mutex + closingFlag bool + closedLock sync.Mutex + closedFlag bool } func NewNetClient( @@ -46,6 +49,7 @@ func NewNetClient( conf := networking.NewConfig(p, h). WithLogger(log). WithWriteTimeout(networkTimeout). + WithKeepAliveInterval(pingInterval). WithSlogAttribute(slog.String("suite", t.Name())). WithSlogAttribute(slog.String("impl", impl.String())) @@ -79,7 +83,6 @@ func (c *NetClient) SendHandshake() { } func (c *NetClient) SendMessage(m proto.Message) { - // TODO: Postpone message sending if the connection is reconnecting. b, err := m.MarshalBinary() require.NoError(c.t, err, "failed to marshal message to %s node at %q", c.impl.String(), c.s.RemoteAddr()) _, err = c.s.Write(b) @@ -87,22 +90,23 @@ func (c *NetClient) SendMessage(m proto.Message) { } func (c *NetClient) Close() { + c.t.Logf("Trying to close connection to %s node at %q", c.impl.String(), c.s.RemoteAddr().String()) + + c.closingLock.Lock() + c.closingFlag = true + c.closingLock.Unlock() + c.closedLock.Lock() defer c.closedLock.Unlock() - if c.closed { + c.t.Logf("Closing connection to %s node at %q (%t)", c.impl.String(), c.s.RemoteAddr().String(), c.closedFlag) + if c.closedFlag { return } - c.closed = true _ = c.s.Close() + c.closedFlag = true } func (c *NetClient) reconnect() { - // Check if the client was manually closed, in which case we don't want to reconnect. - c.closedLock.Lock() - defer c.closedLock.Unlock() - if c.closed { - return - } c.t.Logf("Reconnecting to %q", c.s.RemoteAddr().String()) conn, err := net.Dial("tcp", c.s.RemoteAddr().String()) require.NoError(c.t, err, "failed to dial TCP to %s node", c.impl.String()) @@ -114,6 +118,12 @@ func (c *NetClient) reconnect() { c.SendHandshake() } +func (c *NetClient) closing() bool { + c.closingLock.Lock() + defer c.closingLock.Unlock() + return c.closingFlag +} + type protocol struct { dropLock sync.Mutex drop map[proto.PeerMessageID]struct{} @@ -206,7 +216,7 @@ func (h *handler) OnHandshake(_ *networking.Session, _ networking.Handshake) { func (h *handler) OnClose(s *networking.Session) { h.t.Logf("Connection to %q was closed", s.RemoteAddr()) - if h.client != nil { + if !h.client.closing() && h.client != nil { h.client.reconnect() } } diff --git a/pkg/networking/configuration.go b/pkg/networking/configuration.go index 46fb53f8b..be261d3c9 100644 --- a/pkg/networking/configuration.go +++ b/pkg/networking/configuration.go @@ -57,3 +57,8 @@ func (c *Config) WithKeepAliveDisabled() *Config { c.keepAlive = false return c } + +func (c *Config) WithKeepAliveInterval(interval time.Duration) *Config { + c.keepAliveInterval = interval + return c +} From 67f4b853176ba5832cd124b1c80fc3aecae370dd Mon Sep 17 00:00:00 2001 From: Alexey Kiselev Date: Mon, 25 Nov 2024 19:12:40 +0400 Subject: [PATCH 06/30] Redundant log removed. --- itests/clients/node_client.go | 1 - 1 file changed, 1 deletion(-) diff --git a/itests/clients/node_client.go b/itests/clients/node_client.go index e3597a4e6..95a54d046 100644 --- a/itests/clients/node_client.go +++ b/itests/clients/node_client.go @@ -247,7 +247,6 @@ func (c *NodesClients) SynchronizedWavesBalances( if err != nil { t.Logf("Errors while requesting balances: %v", err) } - t.Log("Entering loop") for { commonHeight := mostCommonHeight(sbs) toRetry := make([]proto.WavesAddress, 0, len(addresses)) From d44fa7ff65f450a25c07cdbb1d2236d268cb084b Mon Sep 17 00:00:00 2001 From: Alexey Kiselev Date: Mon, 25 Nov 2024 20:39:14 +0400 Subject: [PATCH 07/30] Move save int conversion to safecast lib. --- go.mod | 1 + go.sum | 2 ++ pkg/util/common/util.go | 12 +++++------- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/go.mod b/go.mod index 918d874a9..bc8fa47d5 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( filippo.io/edwards25519 v1.1.0 github.com/beevik/ntp v1.4.3 github.com/btcsuite/btcd/btcec/v2 v2.3.4 + github.com/ccoveille/go-safecast v1.2.0 github.com/cenkalti/backoff/v4 v4.3.0 github.com/cespare/xxhash/v2 v2.3.0 github.com/consensys/gnark v0.11.0 diff --git a/go.sum b/go.sum index a7b27b016..84dca2c9c 100644 --- a/go.sum +++ b/go.sum @@ -30,6 +30,8 @@ github.com/btcsuite/btcd/btcec/v2 v2.3.4 h1:3EJjcN70HCu/mwqlUsGK8GcNVyLVxFDlWurT github.com/btcsuite/btcd/btcec/v2 v2.3.4/go.mod h1:zYzJ8etWJQIv1Ogk7OzpWjowwOdXY1W/17j2MW85J04= github.com/btcsuite/btcd/chaincfg/chainhash v1.1.0 h1:59Kx4K6lzOW5w6nFlA0v5+lk/6sjybR934QNHSJZPTQ= github.com/btcsuite/btcd/chaincfg/chainhash v1.1.0/go.mod h1:7SFka0XMvUgj3hfZtydOrQY2mwhPclbT2snogU7SQQc= +github.com/ccoveille/go-safecast v1.2.0 h1:H4X7aosepsU1Mfk+098CTdKpsDH0cfYJ2RmwXFjgvfc= +github.com/ccoveille/go-safecast v1.2.0/go.mod h1:QqwNjxQ7DAqY0C721OIO9InMk9zCwcsO7tnRuHytad8= github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= diff --git a/pkg/util/common/util.go b/pkg/util/common/util.go index b92f0fa3a..9d4e854ba 100644 --- a/pkg/util/common/util.go +++ b/pkg/util/common/util.go @@ -5,7 +5,6 @@ import ( "encoding/base64" "encoding/hex" "fmt" - "math" "math/big" "os/user" "path/filepath" @@ -13,6 +12,7 @@ import ( "strings" "time" + "github.com/ccoveille/go-safecast" "github.com/mr-tron/base58/base58" "github.com/pkg/errors" "golang.org/x/exp/constraints" @@ -247,11 +247,9 @@ func padBytes(p byte, bytes []byte) []byte { } func SafeIntToUint32(v int) uint32 { - if v < 0 { - panic("negative value") - } - if v > math.MaxUint32 { - panic("value is too big") + r, err := safecast.ToUint32(v) + if err != nil { + panic(err) } - return uint32(v) + return r } From df01e6039df8a7aba089e305d9508feeca3ea722 Mon Sep 17 00:00:00 2001 From: Nikolay Eskov Date: Thu, 28 Nov 2024 21:43:57 +0300 Subject: [PATCH 08/30] Fix data race error in 'networking_test' package Implement 'io.Stringer' for 'Session' struct. Data race happens because 'clientHandler' mock in 'TestSessionTimeoutOnHandshake' test reads 'Session' structure at the same time as 'clientSession.Close' call. --- pkg/networking/session.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pkg/networking/session.go b/pkg/networking/session.go index cfbed99ab..b0f760555 100644 --- a/pkg/networking/session.go +++ b/pkg/networking/session.go @@ -6,6 +6,7 @@ import ( "context" "encoding/base64" "errors" + "fmt" "io" "log/slog" "net" @@ -85,6 +86,10 @@ func newSession(ctx context.Context, config *Config, conn io.ReadWriteCloser, tp return s, nil } +func (s *Session) String() string { + return fmt.Sprintf("Session{local=%s,remote=%s}", s.LocalAddr(), s.RemoteAddr()) +} + // LocalAddr returns the local network address. func (s *Session) LocalAddr() net.Addr { if a, ok := s.conn.(addressable); ok { From 00a9ebe1165f0ce7d2d2fab7df64849969dfb78d Mon Sep 17 00:00:00 2001 From: Alexey Kiselev Date: Tue, 10 Dec 2024 18:50:46 +0400 Subject: [PATCH 09/30] Replace atomic.Uint32 with atomic.Bool and use CompareAndSwap there it's possible. Replace random delay with constan to make test not blink. Simplify assertion in test to make it stable. --- pkg/execution/taskgroup.go | 13 +++++-------- pkg/execution/taskgroup_test.go | 8 +------- 2 files changed, 6 insertions(+), 15 deletions(-) diff --git a/pkg/execution/taskgroup.go b/pkg/execution/taskgroup.go index 387d4e007..6550a257f 100644 --- a/pkg/execution/taskgroup.go +++ b/pkg/execution/taskgroup.go @@ -14,12 +14,12 @@ import ( type TaskGroup struct { wg sync.WaitGroup // Counter for active goroutines. - // active is nonzero when the group is "active", meaning there has been at least one call to Run since the group + // active is true when the group is "active", meaning there has been at least one call to Run since the group // was created or the last Wait. // // Together active and errLock work as a kind of resettable sync.Once. The fast path reads active and only // acquires errLock if it discovers setup is needed. - active atomic.Uint32 + active atomic.Bool errLock sync.Mutex // Guards the fields below. err error // First captured error returned from Wait. @@ -56,7 +56,7 @@ func (g *TaskGroup) OnError(handler func(error) error) *TaskGroup { // so the [execute] function should include the interruption logic. func (g *TaskGroup) Run(execute func() error) { g.wg.Add(1) - if g.active.Load() == 0 { + if !g.active.Load() { g.activate() } go func() { @@ -82,9 +82,7 @@ func (g *TaskGroup) Wait() error { defer g.errLock.Unlock() // If the group is still active, deactivate it now. - if g.active.Load() != 0 { - g.active.Store(0) - } + g.active.CompareAndSwap(true, false) return g.err } @@ -93,9 +91,8 @@ func (g *TaskGroup) Wait() error { func (g *TaskGroup) activate() { g.errLock.Lock() defer g.errLock.Unlock() - if g.active.Load() == 0 { + if g.active.CompareAndSwap(false, true) { g.err = nil - g.active.Store(1) } } diff --git a/pkg/execution/taskgroup_test.go b/pkg/execution/taskgroup_test.go index bacedb3fc..ef3924a96 100644 --- a/pkg/execution/taskgroup_test.go +++ b/pkg/execution/taskgroup_test.go @@ -4,7 +4,6 @@ import ( "context" "errors" "math/rand/v2" - "runtime" "sync" "sync/atomic" "testing" @@ -83,8 +82,6 @@ func TestCancelPropagation(t *testing.T) { } }) } - runtime.Gosched() - <-time.After(500 * time.Microsecond) cancel() err := g.Wait() @@ -102,9 +99,6 @@ func TestCancelPropagation(t *testing.T) { } } - assert.NotZero(t, numOK) - assert.NotZero(t, numCanceled) - assert.NotZero(t, numOther) total := int(numOK) + numCanceled + numOther assert.Equal(t, numTasks, total) } @@ -119,7 +113,7 @@ func TestWaitingForFinish(t *testing.T) { select { case <-ctx.Done(): return work(50, nil)() - case <-time.After(randomDuration(60)): + case <-time.After(60 * time.Millisecond): return failure } } From 63a0305a294eb6b3b5038f044bbe977e889276e2 Mon Sep 17 00:00:00 2001 From: Alexey Kiselev Date: Tue, 10 Dec 2024 19:29:05 +0400 Subject: [PATCH 10/30] Assertions added. Style fixed. --- pkg/execution/taskgroup_test.go | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/pkg/execution/taskgroup_test.go b/pkg/execution/taskgroup_test.go index ef3924a96..2678a0a04 100644 --- a/pkg/execution/taskgroup_test.go +++ b/pkg/execution/taskgroup_test.go @@ -95,7 +95,7 @@ func TestCancelPropagation(t *testing.T) { case errors.Is(e, errOther): numOther++ default: - require.FailNow(t, "unexpected error: %v", e) + require.FailNowf(t, "No error is expected", "unexpected error: %v", e) } } @@ -130,6 +130,8 @@ func TestWaitingForFinish(t *testing.T) { } func TestRegression(t *testing.T) { + defer goleak.VerifyNone(t) + t.Run("WaitRace", func(_ *testing.T) { ready := make(chan struct{}) var g execution.TaskGroup @@ -140,20 +142,26 @@ func TestRegression(t *testing.T) { var wg sync.WaitGroup wg.Add(2) - go func() { defer wg.Done(); _ = g.Wait() }() - go func() { defer wg.Done(); _ = g.Wait() }() + go func() { + defer wg.Done() + err := g.Wait() + require.NoError(t, err) + }() + go func() { + defer wg.Done() + err := g.Wait() + require.NoError(t, err) + }() close(ready) wg.Wait() }) t.Run("WaitUnstarted", func(t *testing.T) { - defer func() { - if x := recover(); x != nil { - t.Errorf("Unexpected panic: %v", x) - } - }() - var g execution.TaskGroup - _ = g.Wait() + require.NotPanics(t, func() { + var g execution.TaskGroup + err := g.Wait() + require.NoError(t, err) + }) }) } From 5219227d55a692e85565cdbdd5c7f74466419c1e Mon Sep 17 00:00:00 2001 From: Alexey Kiselev Date: Wed, 11 Dec 2024 12:18:56 +0400 Subject: [PATCH 11/30] Simplified closing and close logic in NetClient. Added logs on handshake rejection to clarify the reason of rejections. Added and used function to configure Session with list of Slog attributes. --- itests/clients/net_client.go | 71 ++++++++++++++------------------- pkg/networking/configuration.go | 8 ++++ 2 files changed, 38 insertions(+), 41 deletions(-) diff --git a/itests/clients/net_client.go b/itests/clients/net_client.go index dc7e9213a..74b7c457d 100644 --- a/itests/clients/net_client.go +++ b/itests/clients/net_client.go @@ -7,6 +7,7 @@ import ( "log/slog" "net" "sync" + "sync/atomic" "testing" "time" @@ -33,25 +34,22 @@ type NetClient struct { c *networking.Config s *networking.Session - closingLock sync.Mutex - closingFlag bool - closedLock sync.Mutex - closedFlag bool + closing atomic.Bool + closed sync.Once } func NewNetClient( ctx context.Context, t testing.TB, impl Implementation, port string, peers []proto.PeerInfo, ) *NetClient { n := networking.NewNetwork() - p := newProtocol(nil) + p := newProtocol(t, nil) h := newHandler(t, peers) log := slogt.New(t) conf := networking.NewConfig(p, h). WithLogger(log). WithWriteTimeout(networkTimeout). WithKeepAliveInterval(pingInterval). - WithSlogAttribute(slog.String("suite", t.Name())). - WithSlogAttribute(slog.String("impl", impl.String())) + WithSlogAttributes(slog.String("suite", t.Name()), slog.String("impl", impl.String())) conn, err := net.Dial("tcp", config.DefaultIP+":"+port) require.NoError(t, err, "failed to dial TCP to %s node", impl.String()) @@ -83,27 +81,18 @@ func (c *NetClient) SendHandshake() { } func (c *NetClient) SendMessage(m proto.Message) { - b, err := m.MarshalBinary() - require.NoError(c.t, err, "failed to marshal message to %s node at %q", c.impl.String(), c.s.RemoteAddr()) - _, err = c.s.Write(b) + _, err := m.WriteTo(c.s) require.NoError(c.t, err, "failed to send message to %s node at %q", c.impl.String(), c.s.RemoteAddr()) } func (c *NetClient) Close() { - c.t.Logf("Trying to close connection to %s node at %q", c.impl.String(), c.s.RemoteAddr().String()) - - c.closingLock.Lock() - c.closingFlag = true - c.closingLock.Unlock() - - c.closedLock.Lock() - defer c.closedLock.Unlock() - c.t.Logf("Closing connection to %s node at %q (%t)", c.impl.String(), c.s.RemoteAddr().String(), c.closedFlag) - if c.closedFlag { - return - } - _ = c.s.Close() - c.closedFlag = true + c.closed.Do(func() { + if c.closing.CompareAndSwap(false, true) { + c.t.Logf("Closing connection to %s node at %q", c.impl.String(), c.s.RemoteAddr().String()) + } + err := c.s.Close() + require.NoError(c.t, err, "failed to close session to %s node at %q", c.impl.String(), c.s.RemoteAddr()) + }) } func (c *NetClient) reconnect() { @@ -118,23 +107,18 @@ func (c *NetClient) reconnect() { c.SendHandshake() } -func (c *NetClient) closing() bool { - c.closingLock.Lock() - defer c.closingLock.Unlock() - return c.closingFlag -} - type protocol struct { + t testing.TB dropLock sync.Mutex drop map[proto.PeerMessageID]struct{} } -func newProtocol(drop []proto.PeerMessageID) *protocol { +func newProtocol(t testing.TB, drop []proto.PeerMessageID) *protocol { m := make(map[proto.PeerMessageID]struct{}) for _, id := range drop { m[id] = struct{}{} } - return &protocol{drop: m} + return &protocol{t: t, drop: m} } func (p *protocol) EmptyHandshake() networking.Handshake { @@ -158,6 +142,17 @@ func (p *protocol) IsAcceptableHandshake(h networking.Handshake) bool { // Reject nodes with incorrect network bytes, unsupported protocol versions, // or a zero nonce (indicating a self-connection). if hs.AppName != appName || hs.Version.Cmp(proto.ProtocolVersion()) < 0 || hs.NodeNonce == 0 { + p.t.Logf("Unacceptable handshake:") + if hs.AppName != appName { + p.t.Logf("\tinvalid application name %q, expected %q", hs.AppName, appName) + } + if hs.Version.Cmp(proto.ProtocolVersion()) < 0 { + p.t.Logf("\tinvalid application version %q should be equal or more than %q", + hs.Version, proto.ProtocolVersion()) + } + if hs.NodeNonce == 0 { + p.t.Logf("\tinvalid node nonce %d", hs.NodeNonce) + } return false } return true @@ -195,14 +190,8 @@ func (h *handler) OnReceive(s *networking.Session, data []byte) { case *proto.GetPeersMessage: h.t.Logf("Received GetPeersMessage from %q", s.RemoteAddr()) rpl := &proto.PeersMessage{Peers: h.peers} - bts, mErr := rpl.MarshalBinary() - if mErr != nil { // Fail test on marshal error. - h.t.Logf("Failed to marshal peers message: %v", mErr) - h.t.FailNow() - return - } - if _, wErr := s.Write(bts); wErr != nil { - h.t.Logf("Failed to send peers message: %v", wErr) + if _, sErr := rpl.WriteTo(s); sErr != nil { + h.t.Logf("Failed to send peers message: %v", sErr) h.t.FailNow() return } @@ -216,7 +205,7 @@ func (h *handler) OnHandshake(_ *networking.Session, _ networking.Handshake) { func (h *handler) OnClose(s *networking.Session) { h.t.Logf("Connection to %q was closed", s.RemoteAddr()) - if !h.client.closing() && h.client != nil { + if !h.client.closing.Load() && h.client != nil { h.client.reconnect() } } diff --git a/pkg/networking/configuration.go b/pkg/networking/configuration.go index be261d3c9..cda0e377e 100644 --- a/pkg/networking/configuration.go +++ b/pkg/networking/configuration.go @@ -53,6 +53,14 @@ func (c *Config) WithSlogAttribute(attr slog.Attr) *Config { return c } +// WithSlogAttributes adds given attributes to the slice of attributes. +func (c *Config) WithSlogAttributes(attrs ...slog.Attr) *Config { + for _, attr := range attrs { + c.attributes = append(c.attributes, attr) + } + return c +} + func (c *Config) WithKeepAliveDisabled() *Config { c.keepAlive = false return c From ff41cf77b4ce422d9866a3efd15261c2b2934a68 Mon Sep 17 00:00:00 2001 From: Alexey Kiselev Date: Wed, 11 Dec 2024 12:40:13 +0400 Subject: [PATCH 12/30] Prepare for new timer in Go 1.23 Co-authored-by: Nikolay Eskov --- pkg/networking/timers.go | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/pkg/networking/timers.go b/pkg/networking/timers.go index 5f2d2949f..6492297a8 100644 --- a/pkg/networking/timers.go +++ b/pkg/networking/timers.go @@ -32,10 +32,11 @@ func (p *timerPool) Get() *time.Timer { } func (p *timerPool) Put(t *time.Timer) { - t.Stop() - select { - case <-t.C: - default: + if !t.Stop() { + select { + case <-t.C: + default: + } } p.p.Put(t) } From e2f697f94a184ae22382030cf9f2e6abc43851bf Mon Sep 17 00:00:00 2001 From: Alexey Kiselev Date: Thu, 12 Dec 2024 11:33:44 +0400 Subject: [PATCH 13/30] Move constant into function were it used. Proper error declaration. --- pkg/networking/network.go | 4 ++-- pkg/networking/timers.go | 3 +-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/pkg/networking/network.go b/pkg/networking/network.go index 6715f24ed..df36a64fa 100644 --- a/pkg/networking/network.go +++ b/pkg/networking/network.go @@ -3,12 +3,12 @@ package networking import ( "context" "errors" - "fmt" "io" ) const Namespace = "NET" +// TODO: Consider special Error type for all [networking] errors. var ( // ErrInvalidConfigurationNoProtocol is used when the configuration has no protocol. ErrInvalidConfigurationNoProtocol = errors.New("invalid configuration: empty protocol") @@ -23,7 +23,7 @@ var ( ErrSessionShutdown = errors.New("session shutdown") // ErrConnectionWriteTimeout indicates that we hit the timeout writing to the underlying stream connection. - ErrConnectionWriteTimeout = fmt.Errorf("connection write timeout") + ErrConnectionWriteTimeout = errors.New("connection write timeout") // ErrKeepAliveProtocolFailure is used when the protocol failed to provide a keep-alive message. ErrKeepAliveProtocolFailure = errors.New("protocol failed to provide a keep-alive message") diff --git a/pkg/networking/timers.go b/pkg/networking/timers.go index 6492297a8..9dd227c8a 100644 --- a/pkg/networking/timers.go +++ b/pkg/networking/timers.go @@ -5,13 +5,12 @@ import ( "time" ) -const initialTimerInterval = time.Hour * 1e6 - type timerPool struct { p *sync.Pool } func newTimerPool() *timerPool { + const initialTimerInterval = time.Hour * 1e6 return &timerPool{ p: &sync.Pool{ New: func() any { From f8326836c751888d67a990f3d5704c8460341a6d Mon Sep 17 00:00:00 2001 From: Alexey Kiselev Date: Fri, 13 Dec 2024 11:07:45 +0400 Subject: [PATCH 14/30] Better way to prevent from running multiple receiveLoops. Shutdown lock replaced with sync.Once. --- pkg/networking/session.go | 51 ++++++++++++++++++--------------------- 1 file changed, 23 insertions(+), 28 deletions(-) diff --git a/pkg/networking/session.go b/pkg/networking/session.go index b0f760555..d0b1d8057 100644 --- a/pkg/networking/session.go +++ b/pkg/networking/session.go @@ -12,6 +12,7 @@ import ( "net" "strings" "sync" + "sync/atomic" "time" "github.com/wavesplatform/gowaves/pkg/execution" @@ -36,10 +37,9 @@ type Session struct { sendLock sync.Mutex // Guards the sendCh. sendCh chan *sendPacket // sendCh is used to send data to the connection. - establishedLock sync.Mutex // Guards the established field. - established bool // Indicates that incoming Handshake was successfully accepted. - shutdownLock sync.Mutex // Guards the shutdown field. - shutdown bool // shutdown is used to safely close the Session. + receiving atomic.Bool // Indicates that receiveLoop already running. + established atomic.Bool // Indicates that incoming Handshake was successfully accepted. + shutdown sync.Once // shutdown is used to safely close the Session. } // NewSession is used to construct a new session. @@ -109,29 +109,24 @@ func (s *Session) RemoteAddr() net.Addr { // Close is used to close the session. It is safe to call Close multiple times from different goroutines, // subsequent calls do nothing. func (s *Session) Close() error { - s.shutdownLock.Lock() - defer s.shutdownLock.Unlock() - - if s.shutdown { - return nil // Fast path - session already closed. - } - s.shutdown = true - - s.logger.Debug("Closing session") - clErr := s.conn.Close() // Close the underlying connection. - if clErr != nil { - s.logger.Warn("Failed to close underlying connection", "error", clErr) - } - s.logger.Debug("Underlying connection closed") + var err error + s.shutdown.Do(func() { + s.logger.Debug("Closing session") + clErr := s.conn.Close() // Close the underlying connection. + if clErr != nil { + s.logger.Warn("Failed to close underlying connection", "error", clErr) + } + s.logger.Debug("Underlying connection closed") - s.cancel() // Cancel the underlying context to interrupt the loops. + s.cancel() // Cancel the underlying context to interrupt the loops. - s.logger.Debug("Waiting for loops to finish") - err := s.g.Wait() // Wait for loops to finish. + s.logger.Debug("Waiting for loops to finish") + err = s.g.Wait() // Wait for loops to finish. - err = errors.Join(err, clErr) // Combine loops finalization errors with connection close error. + err = errors.Join(err, clErr) // Combine loops finalization errors with connection close error. - s.logger.Debug("Session closed", "error", err) + s.logger.Debug("Session closed", "error", err) + }) return err } @@ -253,9 +248,9 @@ func (s *Session) sendLoop() error { // receiveLoop continues to receive data until a fatal error is encountered or underlying connection is closed. // Receive loop works after handshake and accepts only length-prepended messages. func (s *Session) receiveLoop() error { - s.establishedLock.Lock() // Prevents from running multiple receiveLoops. - defer s.establishedLock.Unlock() - + if !s.receiving.CompareAndSwap(false, true) { + return nil // Prevent running multiple receive loops. + } for { if err := s.receive(); err != nil { if errors.Is(err, ErrConnectionClosedOnRead) { @@ -268,7 +263,7 @@ func (s *Session) receiveLoop() error { } func (s *Session) receive() error { - if s.established { + if s.established.Load() { hdr := s.config.protocol.EmptyHeader() return s.readMessage(hdr) } @@ -295,7 +290,7 @@ func (s *Session) readHandshake() error { return ErrUnacceptableHandshake } // Handshake is acceptable, we can switch the session into established state. - s.established = true + s.established.Store(true) s.config.handler.OnHandshake(s, hs) return nil } From c2ad10151ba69b057219d2d855d56dbe0ae05327 Mon Sep 17 00:00:00 2001 From: Alexey Kiselev Date: Fri, 13 Dec 2024 11:57:51 +0400 Subject: [PATCH 15/30] Better data emptyness checks. --- pkg/networking/session.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pkg/networking/session.go b/pkg/networking/session.go index d0b1d8057..aea9209f9 100644 --- a/pkg/networking/session.go +++ b/pkg/networking/session.go @@ -166,8 +166,8 @@ func (s *Session) waitForSend(data []byte) error { } dataCopy := func() { - if data == nil { - return // A nil data is ignored. + if len(data) == 0 { + return // An empty data is ignored. } // In the event of session shutdown or connection write timeout, we need to prevent `send` from reading @@ -213,7 +213,7 @@ func (s *Session) sendLoop() error { s.logger.Debug("Sending data to connection", "data", base64.StdEncoding.EncodeToString(packet.data)) packet.mu.Lock() - if packet.data != nil { + if len(packet.data) != 0 { // Copy the data into the buffer to avoid holding a mutex lock during the writing. _, err := dataBuf.Write(packet.data) if err != nil { From 3aa8a8586168ca4f33f80f48cacc78544d350b66 Mon Sep 17 00:00:00 2001 From: Alexey Kiselev Date: Fri, 13 Dec 2024 12:02:33 +0400 Subject: [PATCH 16/30] Better read error handling. Co-authored-by: Nikolay Eskov --- pkg/networking/session.go | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/pkg/networking/session.go b/pkg/networking/session.go index aea9209f9..9d1b1c73f 100644 --- a/pkg/networking/session.go +++ b/pkg/networking/session.go @@ -298,11 +298,14 @@ func (s *Session) readHandshake() error { func (s *Session) readMessage(hdr Header) error { // Read the header if _, err := hdr.ReadFrom(s.bufRead); err != nil { - if errors.Is(err, io.EOF) || strings.Contains(err.Error(), "closed") || - strings.Contains(err.Error(), "reset by peer") || - strings.Contains(err.Error(), "broken pipe") { // In Docker network built on top of pipe, we get this error on close. + if errors.Is(err, io.EOF) { return ErrConnectionClosedOnRead } + if errMsg := err.Error(); strings.Contains(errMsg, "closed") || + strings.Contains(errMsg, "reset by peer") || + strings.Contains(errMsg, "broken pipe") { // In Docker network built on top of pipe, we get this error on close. + return errors.Join(ErrConnectionClosedOnRead, err) // Wrap the error with ErrConnectionClosedOnRead. + } s.logger.Error("Failed to read header", "error", err) return err } From c08baceed974af3ffa32fa15497223d8579b31ee Mon Sep 17 00:00:00 2001 From: Alexey Kiselev Date: Fri, 13 Dec 2024 12:05:36 +0400 Subject: [PATCH 17/30] Use constructor. --- pkg/networking/session.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/networking/session.go b/pkg/networking/session.go index 9d1b1c73f..ac200d6ac 100644 --- a/pkg/networking/session.go +++ b/pkg/networking/session.go @@ -327,7 +327,7 @@ func (s *Session) readMessage(hdr Header) error { func (s *Session) readMessagePayload(hdr Header, conn io.Reader) error { // Wrap in a limited reader s.logger.Debug("Reading message payload", "len", hdr.PayloadLength()) - conn = &io.LimitedReader{R: conn, N: int64(hdr.PayloadLength())} + conn = io.LimitReader(conn, int64(hdr.PayloadLength())) // Copy into buffer s.receiveLock.Lock() From abced7f1f2a632c5d608a80eaa5aa0d93d620402 Mon Sep 17 00:00:00 2001 From: Alexey Kiselev Date: Sat, 14 Dec 2024 17:59:54 +0400 Subject: [PATCH 18/30] Wrap heavy logging into log level checks. Fix data lock and data access order. --- pkg/networking/session.go | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/pkg/networking/session.go b/pkg/networking/session.go index ac200d6ac..e7f3eba6c 100644 --- a/pkg/networking/session.go +++ b/pkg/networking/session.go @@ -152,7 +152,9 @@ func (s *Session) waitForSend(data []byte) error { timer.Reset(s.config.connectionWriteTimeout) defer s.tp.Put(timer) - s.logger.Debug("Sending data", "data", base64.StdEncoding.EncodeToString(data)) + if s.logger.Enabled(s.ctx, slog.LevelDebug) { + s.logger.Debug("Sending data", "data", base64.StdEncoding.EncodeToString(data)) + } ready := &sendPacket{data: data, err: errCh} select { case s.sendCh <- ready: @@ -210,9 +212,11 @@ func (s *Session) sendLoop() error { return s.ctx.Err() case packet := <-s.sendCh: - s.logger.Debug("Sending data to connection", - "data", base64.StdEncoding.EncodeToString(packet.data)) packet.mu.Lock() + if s.logger.Enabled(s.ctx, slog.LevelDebug) { + s.logger.Debug("Sending data to connection", + "data", base64.StdEncoding.EncodeToString(packet.data)) + } if len(packet.data) != 0 { // Copy the data into the buffer to avoid holding a mutex lock during the writing. _, err := dataBuf.Write(packet.data) @@ -352,7 +356,10 @@ func (s *Session) readMessagePayload(hdr Header, conn io.Reader) error { // We lock the buffer from modification on the time of invocation of OnReceive handler. // The slice of bytes passed into the handler is only valid for the duration of the handler invocation. // So inside the handler better deserialize message or make a copy of the bytes. - s.logger.Debug("Invoking OnReceive handler", "message", base64.StdEncoding.EncodeToString(s.receiveBuffer.Bytes())) + if s.logger.Enabled(s.ctx, slog.LevelDebug) { + s.logger.Debug("Invoking OnReceive handler", "message", + base64.StdEncoding.EncodeToString(s.receiveBuffer.Bytes())) + } s.config.handler.OnReceive(s, s.receiveBuffer.Bytes()) // Invoke OnReceive handler. s.receiveBuffer.Reset() // Reset the buffer for the next message. return nil From 57b9ffbd65388121fce922323a5ca9ac840903cc Mon Sep 17 00:00:00 2001 From: Alexey Kiselev Date: Mon, 16 Dec 2024 12:58:39 +0400 Subject: [PATCH 19/30] Session configuration accepts slog handler to set up logging. Discarding slog handler implemented and used instead of setting default slog logger. Checks on interval values added to Session constructor. --- itests/clients/net_client.go | 2 +- pkg/networking/configuration.go | 9 ++++----- pkg/networking/logging.go | 18 ++++++++++++++++++ pkg/networking/network.go | 6 ++++++ pkg/networking/session.go | 22 ++++++++++++++-------- pkg/networking/session_test.go | 2 +- 6 files changed, 44 insertions(+), 15 deletions(-) create mode 100644 pkg/networking/logging.go diff --git a/itests/clients/net_client.go b/itests/clients/net_client.go index 74b7c457d..821999517 100644 --- a/itests/clients/net_client.go +++ b/itests/clients/net_client.go @@ -46,7 +46,7 @@ func NewNetClient( h := newHandler(t, peers) log := slogt.New(t) conf := networking.NewConfig(p, h). - WithLogger(log). + WithSlogHandler(log.Handler()). WithWriteTimeout(networkTimeout). WithKeepAliveInterval(pingInterval). WithSlogAttributes(slog.String("suite", t.Name()), slog.String("impl", impl.String())) diff --git a/pkg/networking/configuration.go b/pkg/networking/configuration.go index cda0e377e..151bc8541 100644 --- a/pkg/networking/configuration.go +++ b/pkg/networking/configuration.go @@ -12,7 +12,7 @@ const ( // Config allows to set some parameters of the [Conn] or it's underlying connection. type Config struct { - logger *slog.Logger + slogHandler slog.Handler protocol Protocol handler Handler keepAlive bool @@ -25,7 +25,6 @@ type Config struct { // Other parameters are set to their default values. func NewConfig(p Protocol, h Handler) *Config { return &Config{ - logger: slog.Default(), protocol: p, handler: h, keepAlive: true, @@ -35,9 +34,9 @@ func NewConfig(p Protocol, h Handler) *Config { } } -// WithLogger sets the logger. -func (c *Config) WithLogger(logger *slog.Logger) *Config { - c.logger = logger +// WithSlogHandler sets the slog handler. +func (c *Config) WithSlogHandler(handler slog.Handler) *Config { + c.slogHandler = handler return c } diff --git a/pkg/networking/logging.go b/pkg/networking/logging.go new file mode 100644 index 000000000..94338ab31 --- /dev/null +++ b/pkg/networking/logging.go @@ -0,0 +1,18 @@ +package networking + +import ( + "context" + "log/slog" +) + +// TODO: Remove this file and the handler when the default [slog.DiscardHandler] will be introduced in +// Go version 1.24. See https://go-review.googlesource.com/c/go/+/626486. + +// discardingHandler is a logger that discards all log messages. +// It is used when no slog handler is provided in the [Config]. +type discardingHandler struct{} + +func (h discardingHandler) Enabled(context.Context, slog.Level) bool { return false } +func (h discardingHandler) Handle(context.Context, slog.Record) error { return nil } +func (h discardingHandler) WithAttrs([]slog.Attr) slog.Handler { return h } +func (h discardingHandler) WithGroup(string) slog.Handler { return h } diff --git a/pkg/networking/network.go b/pkg/networking/network.go index df36a64fa..f145b4cb1 100644 --- a/pkg/networking/network.go +++ b/pkg/networking/network.go @@ -16,6 +16,12 @@ var ( // ErrInvalidConfigurationNoHandler is used when the configuration has no handler. ErrInvalidConfigurationNoHandler = errors.New("invalid configuration: empty handler") + // ErrInvalidConfigurationNoKeepAliveInterval is used when the configuration has an invalid keep-alive interval. + ErrInvalidConfigurationNoKeepAliveInterval = errors.New("invalid configuration: invalid keep-alive interval value") + + // ErrInvalidConfigurationNoWriteTimeout is used when the configuration has an invalid write timeout. + ErrInvalidConfigurationNoWriteTimeout = errors.New("invalid configuration: invalid write timeout value") + // ErrUnacceptableHandshake is used when the handshake is not accepted. ErrUnacceptableHandshake = errors.New("handshake is not accepted") diff --git a/pkg/networking/session.go b/pkg/networking/session.go index e7f3eba6c..5723b1913 100644 --- a/pkg/networking/session.go +++ b/pkg/networking/session.go @@ -50,9 +50,16 @@ func newSession(ctx context.Context, config *Config, conn io.ReadWriteCloser, tp if config.handler == nil { return nil, ErrInvalidConfigurationNoHandler } + if config.keepAlive && config.keepAliveInterval <= 0 { + return nil, ErrInvalidConfigurationNoKeepAliveInterval + } + if config.connectionWriteTimeout <= 0 { + return nil, ErrInvalidConfigurationNoWriteTimeout + } if tp == nil { return nil, ErrEmptyTimerPool } + sCtx, cancel := context.WithCancel(ctx) s := &Session{ g: execution.NewTaskGroup(suppressContextCancellationError), @@ -65,17 +72,16 @@ func newSession(ctx context.Context, config *Config, conn io.ReadWriteCloser, tp sendCh: make(chan *sendPacket, 1), // TODO: Make the size of send channel configurable. } - attributes := []any{ - slog.String("namespace", Namespace), - slog.String("remote", s.RemoteAddr().String()), + if config.slogHandler == nil { + config.slogHandler = discardingHandler{} } - attributes = append(attributes, config.attributes...) - if config.logger == nil { - s.logger = slog.Default().With(attributes...) - } else { - s.logger = config.logger.With(attributes...) + sa := [...]any{ + slog.String("namespace", Namespace), + slog.String("remote", s.RemoteAddr().String()), } + attrs := append(sa[:], config.attributes...) + s.logger = slog.New(config.slogHandler).With(attrs...) s.g.Run(s.receiveLoop) s.g.Run(s.sendLoop) diff --git a/pkg/networking/session_test.go b/pkg/networking/session_test.go index 51fbd8369..cbeb9ad8e 100644 --- a/pkg/networking/session_test.go +++ b/pkg/networking/session_test.go @@ -385,7 +385,7 @@ func TestCloseParentContext(t *testing.T) { func testConfig(t testing.TB, p networking.Protocol, h networking.Handler, direction string) *networking.Config { log := slogt.New(t) return networking.NewConfig(p, h). - WithLogger(log). + WithSlogHandler(log.Handler()). WithWriteTimeout(1 * time.Second). WithKeepAliveDisabled(). WithSlogAttribute(slog.String("direction", direction)) From 14420bc474194ec3c46aa4eb8d44bc328ae80072 Mon Sep 17 00:00:00 2001 From: Alexey Kiselev Date: Mon, 16 Dec 2024 17:22:42 +0400 Subject: [PATCH 20/30] Close error channel on sending data successfully. Better error channel passing. Reset receiving buffer by deffering. --- pkg/networking/session.go | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/pkg/networking/session.go b/pkg/networking/session.go index 5723b1913..2599c9dd4 100644 --- a/pkg/networking/session.go +++ b/pkg/networking/session.go @@ -192,8 +192,12 @@ func (s *Session) waitForSend(data []byte) error { } select { - case err := <-errCh: - s.logger.Debug("Data sent", "error", err) + case err, ok := <-errCh: + if !ok { + s.logger.Debug("Data sent successfully") + return nil // No error, data was sent successfully. + } + s.logger.Debug("Error sending data", "error", err) return err case <-s.ctx.Done(): dataCopy() @@ -249,8 +253,8 @@ func (s *Session) sendLoop() error { s.logger.Debug("Data written into connection") } - // No error, successful send - s.asyncSendErr(packet.err, nil) + // No error, close the channel. + close(packet.err) } } } @@ -347,6 +351,7 @@ func (s *Session) readMessagePayload(hdr Header, conn io.Reader) error { // Allocate the receiving buffer just-in-time to fit the full message. s.receiveBuffer = bytes.NewBuffer(make([]byte, 0, hdr.HeaderLength()+hdr.PayloadLength())) } + defer s.receiveBuffer.Reset() _, err := hdr.WriteTo(s.receiveBuffer) if err != nil { s.logger.Error("Failed to write header to receiving buffer", "error", err) @@ -367,7 +372,6 @@ func (s *Session) readMessagePayload(hdr Header, conn io.Reader) error { base64.StdEncoding.EncodeToString(s.receiveBuffer.Bytes())) } s.config.handler.OnReceive(s, s.receiveBuffer.Bytes()) // Invoke OnReceive handler. - s.receiveBuffer.Reset() // Reset the buffer for the next message. return nil } @@ -399,11 +403,11 @@ func (s *Session) keepaliveLoop() error { type sendPacket struct { mu sync.Mutex // Protects data from unsafe reads. data []byte - err chan error + err chan<- error } // asyncSendErr is used to try an async send of an error. -func (s *Session) asyncSendErr(ch chan error, err error) { +func (s *Session) asyncSendErr(ch chan<- error, err error) { if ch == nil { return } From 412377f81c21605f9437ae677f1f6e963906d84f Mon Sep 17 00:00:00 2001 From: Alexey Kiselev Date: Mon, 16 Dec 2024 17:30:20 +0400 Subject: [PATCH 21/30] Better error handling while reading. Co-authored-by: Nikolay Eskov --- pkg/networking/session.go | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/pkg/networking/session.go b/pkg/networking/session.go index 2599c9dd4..04c1e6062 100644 --- a/pkg/networking/session.go +++ b/pkg/networking/session.go @@ -290,10 +290,13 @@ func (s *Session) readHandshake() error { hs := s.config.protocol.EmptyHandshake() _, err := hs.ReadFrom(s.bufRead) if err != nil { - if errors.Is(err, io.EOF) || strings.Contains(err.Error(), "closed") || - strings.Contains(err.Error(), "reset by peer") { + if errors.Is(err, io.EOF) { return ErrConnectionClosedOnRead } + if errMsg := err.Error(); strings.Contains(errMsg, "closed") || + strings.Contains(errMsg, "reset by peer") { + return errors.Join(ErrConnectionClosedOnRead, err) // Wrap the error with ErrConnectionClosedOnRead. + } s.logger.Error("Failed to read handshake from connection", "error", err) return err } From 52a893ea073ee78f994690be062ee3207b7637de Mon Sep 17 00:00:00 2001 From: Alexey Kiselev Date: Mon, 16 Dec 2024 17:52:45 +0400 Subject: [PATCH 22/30] Fine error assertions. --- pkg/networking/session_test.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/pkg/networking/session_test.go b/pkg/networking/session_test.go index cbeb9ad8e..c89862531 100644 --- a/pkg/networking/session_test.go +++ b/pkg/networking/session_test.go @@ -125,7 +125,7 @@ func TestSessionTimeoutOnHandshake(t *testing.T) { // Send handshake to server, but writing will block because the clientConn is locked. n, err := clientSession.Write([]byte("hello")) - require.Error(t, err) + require.ErrorIs(t, err, networking.ErrConnectionWriteTimeout) assert.Equal(t, 0, n) runtime.Gosched() @@ -136,7 +136,7 @@ func TestSessionTimeoutOnHandshake(t *testing.T) { // Unlock "timeout" and close client. pc.writeBlocker.Unlock() err = clientSession.Close() - assert.Error(t, err) + assert.ErrorIs(t, err, io.ErrClosedPipe) } func TestSessionTimeoutOnMessage(t *testing.T) { @@ -199,7 +199,7 @@ func TestSessionTimeoutOnMessage(t *testing.T) { clientWG.Wait() // Wait for pipe to be locked. // On receiving handshake from server, send the message back to server. _, msgErr := clientSession.Write(encodeMessage("Hello session")) - require.Error(t, msgErr) + require.ErrorIs(t, msgErr, networking.ErrConnectionWriteTimeout) }) time.Sleep(1 * time.Second) // Let timeout occur. @@ -210,7 +210,7 @@ func TestSessionTimeoutOnMessage(t *testing.T) { pc.writeBlocker.Unlock() // Unlock the pipe. err = clientSession.Close() - assert.Error(t, err) // Expect error because connection to the server already closed. + assert.ErrorIs(t, err, io.ErrClosedPipe) // Expect error because connection to the server already closed. } func TestDoubleClose(t *testing.T) { @@ -305,13 +305,13 @@ func TestOnClosedByOtherSide(t *testing.T) { // Try to send message to server, but it will fail because server is already closed. time.Sleep(10 * time.Millisecond) // Wait for server to close. _, msgErr := clientSession.Write(encodeMessage("Hello session")) - require.Error(t, msgErr) + require.ErrorIs(t, msgErr, io.ErrClosedPipe) wg.Done() }) wg.Wait() // Wait for client to finish. err = clientSession.Close() - assert.Error(t, err) // Close reports the same error, because it was registered in the send loop. + assert.ErrorIs(t, err, io.ErrClosedPipe) // Close reports the same error, because it was registered in the send loop. } func TestCloseParentContext(t *testing.T) { @@ -370,7 +370,7 @@ func TestCloseParentContext(t *testing.T) { // Try to send message to server, but it will fail because server is already closed. time.Sleep(10 * time.Millisecond) // Wait for server to close. _, msgErr := clientSession.Write(encodeMessage("Hello session")) - require.Error(t, msgErr) + require.ErrorIs(t, msgErr, networking.ErrSessionShutdown) wg.Done() }) From 77633f3a39854a80a51d38b7effee9d97009e88b Mon Sep 17 00:00:00 2001 From: Alexey Kiselev Date: Mon, 16 Dec 2024 18:54:59 +0400 Subject: [PATCH 23/30] Fix blinking test. --- pkg/networking/session_test.go | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/pkg/networking/session_test.go b/pkg/networking/session_test.go index c89862531..49759c03a 100644 --- a/pkg/networking/session_test.go +++ b/pkg/networking/session_test.go @@ -6,7 +6,6 @@ import ( "errors" "io" "log/slog" - "runtime" "sync" "testing" "time" @@ -98,6 +97,7 @@ func TestSessionTimeoutOnHandshake(t *testing.T) { defer goleak.VerifyNone(t) mockProtocol := netmocks.NewMockProtocol(t) + mockProtocol.On("EmptyHandshake").Return(&textHandshake{}, nil) clientHandler := netmocks.NewMockHandler(t) serverHandler := netmocks.NewMockHandler(t) @@ -110,33 +110,36 @@ func TestSessionTimeoutOnHandshake(t *testing.T) { clientSession, err := net.NewSession(ctx, clientConn, testConfig(t, mockProtocol, clientHandler, "client")) require.NoError(t, err) + clientHandler.On("OnClose", clientSession).Return() + serverSession, err := net.NewSession(ctx, serverConn, testConfig(t, mockProtocol, serverHandler, "server")) require.NoError(t, err) - - mockProtocol.On("EmptyHandshake").Return(&textHandshake{}, nil) serverHandler.On("OnClose", serverSession).Return() - clientHandler.On("OnClose", clientSession).Return() // Lock pc, ok := clientConn.(*pipeConn) require.True(t, ok) pc.writeBlocker.Lock() - runtime.Gosched() // Send handshake to server, but writing will block because the clientConn is locked. n, err := clientSession.Write([]byte("hello")) require.ErrorIs(t, err, networking.ErrConnectionWriteTimeout) assert.Equal(t, 0, n) - runtime.Gosched() - err = serverSession.Close() assert.NoError(t, err) + wg := new(sync.WaitGroup) + wg.Add(1) + go func() { + err = clientSession.Close() + assert.ErrorIs(t, err, io.ErrClosedPipe) + wg.Done() + }() + // Unlock "timeout" and close client. pc.writeBlocker.Unlock() - err = clientSession.Close() - assert.ErrorIs(t, err, io.ErrClosedPipe) + wg.Wait() } func TestSessionTimeoutOnMessage(t *testing.T) { From df77a57d016707204e4c071df0f9c89102c04374 Mon Sep 17 00:00:00 2001 From: Alexey Kiselev Date: Thu, 19 Dec 2024 12:56:08 +0400 Subject: [PATCH 24/30] Better configuration handling. Co-authored-by: Nikolay Eskov --- pkg/networking/session.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pkg/networking/session.go b/pkg/networking/session.go index 04c1e6062..2b7f6e72a 100644 --- a/pkg/networking/session.go +++ b/pkg/networking/session.go @@ -72,8 +72,9 @@ func newSession(ctx context.Context, config *Config, conn io.ReadWriteCloser, tp sendCh: make(chan *sendPacket, 1), // TODO: Make the size of send channel configurable. } - if config.slogHandler == nil { - config.slogHandler = discardingHandler{} + slogHandler := config.slogHandler + if slogHandler == nil { + slogHandler = discardingHandler{} } sa := [...]any{ @@ -81,7 +82,7 @@ func newSession(ctx context.Context, config *Config, conn io.ReadWriteCloser, tp slog.String("remote", s.RemoteAddr().String()), } attrs := append(sa[:], config.attributes...) - s.logger = slog.New(config.slogHandler).With(attrs...) + s.logger = slog.New(slogHandler).With(attrs...) s.g.Run(s.receiveLoop) s.g.Run(s.sendLoop) From 7b3fffbcc59368cec2562d4ab9d08b34a9e250f5 Mon Sep 17 00:00:00 2001 From: Alexey Kiselev Date: Thu, 19 Dec 2024 15:34:32 +0400 Subject: [PATCH 25/30] Fixed blinking test TestCloseParentContext. Wait group added to wait for client to finish sending handshake. Better wait groups naming. --- pkg/networking/session_test.go | 46 +++++++++++++++++++--------------- 1 file changed, 26 insertions(+), 20 deletions(-) diff --git a/pkg/networking/session_test.go b/pkg/networking/session_test.go index 49759c03a..da83bcaff 100644 --- a/pkg/networking/session_test.go +++ b/pkg/networking/session_test.go @@ -330,7 +330,6 @@ func TestCloseParentContext(t *testing.T) { serverHandler := netmocks.NewMockHandler(t) ctx, cancel := context.WithCancel(context.Background()) - defer cancel() clientConn, serverConn := testConnPipe() net := networking.NewNetwork() @@ -340,44 +339,51 @@ func TestCloseParentContext(t *testing.T) { serverSession, err := net.NewSession(ctx, serverConn, testConfig(t, mockProtocol, serverHandler, "server")) require.NoError(t, err) - var closeWG sync.WaitGroup - closeWG.Add(1) + clientWG := new(sync.WaitGroup) + clientWG.Add(1) // Wait for client to send Handshake to server. - var wg sync.WaitGroup - wg.Add(2) + serverWG := new(sync.WaitGroup) + serverWG.Add(1) // Wait for server to send Handshake to client, after that we will close the parent context. + + testWG := new(sync.WaitGroup) + testWG.Add(2) // Wait for both client and server to finish. serverHandler.On("OnClose", serverSession).Return() sc1 := serverHandler.On("OnHandshake", serverSession, &textHandshake{v: "hello"}).Once().Return() sc1.Run(func(_ mock.Arguments) { + clientWG.Wait() // Wait for client to send handshake, start replying with Handshake only after that. n, wErr := serverSession.Write([]byte("hello")) assert.NoError(t, wErr) assert.Equal(t, 5, n) go func() { - closeWG.Wait() // Wait for client to receive server handshake. - cancel() // Close parent context. - wg.Done() + serverWG.Wait() // Wait for client to receive server handshake. + cancel() // Close parent context. + testWG.Done() }() }) clientHandler.On("OnClose", clientSession).Return() - // Send handshake to server. - n, err := clientSession.Write([]byte("hello")) - require.NoError(t, err) - assert.Equal(t, 5, n) - cs1 := clientHandler.On("OnHandshake", clientSession, &textHandshake{v: "hello"}).Once().Return() cs1.Run(func(_ mock.Arguments) { // On receiving handshake from server, signal to close the server. - closeWG.Done() - // Try to send message to server, but it will fail because server is already closed. - time.Sleep(10 * time.Millisecond) // Wait for server to close. - _, msgErr := clientSession.Write(encodeMessage("Hello session")) - require.ErrorIs(t, msgErr, networking.ErrSessionShutdown) - wg.Done() + serverWG.Done() + go func() { + // Try to send message to server, but it will fail because server is already closed. + time.Sleep(10 * time.Millisecond) // Wait for server to close. + _, msgErr := clientSession.Write(encodeMessage("Hello session")) + require.ErrorIs(t, msgErr, networking.ErrSessionShutdown) + testWG.Done() + }() }) - wg.Wait() // Wait for client to finish. + // Send handshake to server. + n, err := clientSession.Write([]byte("hello")) + require.NoError(t, err) + assert.Equal(t, 5, n) + clientWG.Done() // Signal that handshake was sent to server. + + testWG.Wait() // Wait for all interactions to finish. err = clientSession.Close() assert.NoError(t, err) From d2e2646eb9827bb883d382267d1c13fe0e61d153 Mon Sep 17 00:00:00 2001 From: Alexey Kiselev Date: Thu, 19 Dec 2024 16:54:17 +0400 Subject: [PATCH 26/30] Better test workflow. Better wait group naming. --- pkg/networking/session_test.go | 59 ++++++++++++++++++++-------------- 1 file changed, 34 insertions(+), 25 deletions(-) diff --git a/pkg/networking/session_test.go b/pkg/networking/session_test.go index da83bcaff..fddc2c38c 100644 --- a/pkg/networking/session_test.go +++ b/pkg/networking/session_test.go @@ -170,42 +170,51 @@ func TestSessionTimeoutOnMessage(t *testing.T) { serverHandler.On("OnClose", serverSession).Return() - var serverWG sync.WaitGroup - var clientWG sync.WaitGroup - serverWG.Add(1) - clientWG.Add(1) - go func() { - sc1 := serverHandler.On("OnHandshake", serverSession, &textHandshake{v: "hello"}).Once().Return() - sc1.Run(func(_ mock.Arguments) { - n, wErr := serverSession.Write([]byte("hello")) - require.NoError(t, wErr) - assert.Equal(t, 5, n) - serverWG.Done() - }) - serverWG.Wait() // Wait for finishing handshake before closing the pipe. + clientWG := new(sync.WaitGroup) + clientWG.Add(1) // Wait for client to send Handshake to server. - // Lock pipe after replying with the handshake from server. - pc.writeBlocker.Lock() - clientWG.Done() // Signal that pipe is locked. - }() + serverWG := new(sync.WaitGroup) + serverWG.Add(1) // Wait for server to reply with Handshake to client. - serverHandler.On("OnClose", serverSession).Return() - clientHandler.On("OnClose", clientSession).Return() + pipeWG := new(sync.WaitGroup) + pipeWG.Add(1) // Wait for pipe to be locked. - // Send handshake to server. - n, err := clientSession.Write([]byte("hello")) - require.NoError(t, err) - assert.Equal(t, 5, n) + testWG := new(sync.WaitGroup) + testWG.Add(1) // Wait for client fail by timeout. + + serverHandler.On("OnClose", serverSession).Return() + sc1 := serverHandler.On("OnHandshake", serverSession, &textHandshake{v: "hello"}).Once().Return() + sc1.Run(func(_ mock.Arguments) { + clientWG.Wait() // Wait for client to send handshake, start replying with Handshake only after that. + n, wErr := serverSession.Write([]byte("hello")) + require.NoError(t, wErr) + assert.Equal(t, 5, n) + serverWG.Done() + }) + clientHandler.On("OnClose", clientSession).Return() cs1 := clientHandler.On("OnHandshake", clientSession, &textHandshake{v: "hello"}).Once().Return() cs1.Run(func(_ mock.Arguments) { - clientWG.Wait() // Wait for pipe to be locked. + pipeWG.Wait() // Wait for pipe to be locked. // On receiving handshake from server, send the message back to server. _, msgErr := clientSession.Write(encodeMessage("Hello session")) require.ErrorIs(t, msgErr, networking.ErrConnectionWriteTimeout) + testWG.Done() }) - time.Sleep(1 * time.Second) // Let timeout occur. + go func() { + serverWG.Wait() // Wait for finishing handshake before closing the pipe. + pc.writeBlocker.Lock() // Lock pipe after replying with the handshake from server. + pipeWG.Done() // Signal that pipe is locked. + }() + + // Send handshake to server. + n, err := clientSession.Write([]byte("hello")) + require.NoError(t, err) + assert.Equal(t, 5, n) + clientWG.Done() // Signal that handshake was sent to server. + + testWG.Wait() err = serverSession.Close() assert.NoError(t, err) // Expect no error on the server side. From 9599840d330924fc717ba7bad4d1ecdfc0579e8a Mon Sep 17 00:00:00 2001 From: Alexey Kiselev Date: Fri, 20 Dec 2024 14:08:42 +0400 Subject: [PATCH 27/30] Fix deadlock in test by introducing wait group instead of sleep. --- pkg/networking/session_test.go | 39 ++++++++++++++++++++-------------- 1 file changed, 23 insertions(+), 16 deletions(-) diff --git a/pkg/networking/session_test.go b/pkg/networking/session_test.go index fddc2c38c..c39220027 100644 --- a/pkg/networking/session_test.go +++ b/pkg/networking/session_test.go @@ -282,46 +282,53 @@ func TestOnClosedByOtherSide(t *testing.T) { serverSession, err := net.NewSession(ctx, serverConn, testConfig(t, mockProtocol, serverHandler, "server")) require.NoError(t, err) - var closeWG sync.WaitGroup - closeWG.Add(1) + clientWG := new(sync.WaitGroup) + clientWG.Add(1) // Wait for client to send Handshake to server. + + serverWG := new(sync.WaitGroup) + serverWG.Add(1) // Wait for server to send Handshake to client, after that close the connection from server. + closeWG := new(sync.WaitGroup) + closeWG.Add(1) // Wait for server to close the connection. - var wg sync.WaitGroup - wg.Add(2) + testWG := new(sync.WaitGroup) + testWG.Add(2) // Wait for both client and server to finish. serverHandler.On("OnClose", serverSession).Return() sc1 := serverHandler.On("OnHandshake", serverSession, &textHandshake{v: "hello"}).Once().Return() sc1.Run(func(_ mock.Arguments) { + clientWG.Wait() // Wait for client to send handshake, start replying with Handshake only after that. n, wErr := serverSession.Write([]byte("hello")) assert.NoError(t, wErr) assert.Equal(t, 5, n) go func() { // Close server after client received the handshake from server. - closeWG.Wait() // Wait for client to receive server handshake. + serverWG.Wait() // Wait for client to receive server handshake. clErr := serverSession.Close() assert.NoError(t, clErr) - wg.Done() + closeWG.Done() + testWG.Done() }() }) clientHandler.On("OnClose", clientSession).Return() - - // Send handshake to server. - n, err := clientSession.Write([]byte("hello")) - require.NoError(t, err) - assert.Equal(t, 5, n) - cs1 := clientHandler.On("OnHandshake", clientSession, &textHandshake{v: "hello"}).Once().Return() cs1.Run(func(_ mock.Arguments) { // On receiving handshake from server, signal to close the server. - closeWG.Done() + serverWG.Done() // Try to send message to server, but it will fail because server is already closed. - time.Sleep(10 * time.Millisecond) // Wait for server to close. + closeWG.Wait() // Wait for server to close. _, msgErr := clientSession.Write(encodeMessage("Hello session")) require.ErrorIs(t, msgErr, io.ErrClosedPipe) - wg.Done() + testWG.Done() }) - wg.Wait() // Wait for client to finish. + // Send handshake to server. + n, err := clientSession.Write([]byte("hello")) + require.NoError(t, err) + assert.Equal(t, 5, n) + clientWG.Done() // Signal that handshake was sent to server. + + testWG.Wait() // Wait for client to finish. err = clientSession.Close() assert.ErrorIs(t, err, io.ErrClosedPipe) // Close reports the same error, because it was registered in the send loop. } From 63ade4e915e9786fadbad50cbe91006ba34c1157 Mon Sep 17 00:00:00 2001 From: Alexey Kiselev Date: Fri, 27 Dec 2024 19:34:05 +0400 Subject: [PATCH 28/30] Internal sendPacket reimplemented using io.Reader. Data restoration function removed. Handler's OnReceive use io.Reader to pass received data. Tests updated. Mocks regenerated. --- pkg/networking/handler.go | 4 ++- pkg/networking/mocks/handler.go | 20 ++++++----- pkg/networking/mocks/header.go | 6 ++-- pkg/networking/mocks/protocol.go | 8 ++--- pkg/networking/session.go | 57 ++++++++++---------------------- pkg/networking/session_test.go | 6 ++-- 6 files changed, 42 insertions(+), 59 deletions(-) diff --git a/pkg/networking/handler.go b/pkg/networking/handler.go index 2f4f62587..81b81cb39 100644 --- a/pkg/networking/handler.go +++ b/pkg/networking/handler.go @@ -1,9 +1,11 @@ package networking +import "io" + // Handler is an interface for handling new messages, handshakes and session close events. type Handler interface { // OnReceive fired on new message received. - OnReceive(*Session, []byte) + OnReceive(*Session, io.Reader) // OnHandshake fired on new Handshake received. OnHandshake(*Session, Handshake) diff --git a/pkg/networking/mocks/handler.go b/pkg/networking/mocks/handler.go index d7ba29dd3..a11fdc547 100644 --- a/pkg/networking/mocks/handler.go +++ b/pkg/networking/mocks/handler.go @@ -1,8 +1,10 @@ -// Code generated by mockery v2.46.3. DO NOT EDIT. +// Code generated by mockery v2.50.1. DO NOT EDIT. package networking import ( + io "io" + mock "github.com/stretchr/testify/mock" networking "github.com/wavesplatform/gowaves/pkg/networking" ) @@ -49,7 +51,7 @@ func (_c *MockHandler_OnClose_Call) Return() *MockHandler_OnClose_Call { } func (_c *MockHandler_OnClose_Call) RunAndReturn(run func(*networking.Session)) *MockHandler_OnClose_Call { - _c.Call.Return(run) + _c.Run(run) return _c } @@ -83,12 +85,12 @@ func (_c *MockHandler_OnHandshake_Call) Return() *MockHandler_OnHandshake_Call { } func (_c *MockHandler_OnHandshake_Call) RunAndReturn(run func(*networking.Session, networking.Handshake)) *MockHandler_OnHandshake_Call { - _c.Call.Return(run) + _c.Run(run) return _c } // OnReceive provides a mock function with given fields: _a0, _a1 -func (_m *MockHandler) OnReceive(_a0 *networking.Session, _a1 []byte) { +func (_m *MockHandler) OnReceive(_a0 *networking.Session, _a1 io.Reader) { _m.Called(_a0, _a1) } @@ -99,14 +101,14 @@ type MockHandler_OnReceive_Call struct { // OnReceive is a helper method to define mock.On call // - _a0 *networking.Session -// - _a1 []byte +// - _a1 io.Reader func (_e *MockHandler_Expecter) OnReceive(_a0 interface{}, _a1 interface{}) *MockHandler_OnReceive_Call { return &MockHandler_OnReceive_Call{Call: _e.mock.On("OnReceive", _a0, _a1)} } -func (_c *MockHandler_OnReceive_Call) Run(run func(_a0 *networking.Session, _a1 []byte)) *MockHandler_OnReceive_Call { +func (_c *MockHandler_OnReceive_Call) Run(run func(_a0 *networking.Session, _a1 io.Reader)) *MockHandler_OnReceive_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(*networking.Session), args[1].([]byte)) + run(args[0].(*networking.Session), args[1].(io.Reader)) }) return _c } @@ -116,8 +118,8 @@ func (_c *MockHandler_OnReceive_Call) Return() *MockHandler_OnReceive_Call { return _c } -func (_c *MockHandler_OnReceive_Call) RunAndReturn(run func(*networking.Session, []byte)) *MockHandler_OnReceive_Call { - _c.Call.Return(run) +func (_c *MockHandler_OnReceive_Call) RunAndReturn(run func(*networking.Session, io.Reader)) *MockHandler_OnReceive_Call { + _c.Run(run) return _c } diff --git a/pkg/networking/mocks/header.go b/pkg/networking/mocks/header.go index 1de986214..eabade26a 100644 --- a/pkg/networking/mocks/header.go +++ b/pkg/networking/mocks/header.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.46.3. DO NOT EDIT. +// Code generated by mockery v2.50.1. DO NOT EDIT. package networking @@ -21,7 +21,7 @@ func (_m *MockHeader) EXPECT() *MockHeader_Expecter { return &MockHeader_Expecter{mock: &_m.Mock} } -// HeaderLength provides a mock function with given fields: +// HeaderLength provides a mock function with no fields func (_m *MockHeader) HeaderLength() uint32 { ret := _m.Called() @@ -66,7 +66,7 @@ func (_c *MockHeader_HeaderLength_Call) RunAndReturn(run func() uint32) *MockHea return _c } -// PayloadLength provides a mock function with given fields: +// PayloadLength provides a mock function with no fields func (_m *MockHeader) PayloadLength() uint32 { ret := _m.Called() diff --git a/pkg/networking/mocks/protocol.go b/pkg/networking/mocks/protocol.go index 19afa9cff..dc30f5d74 100644 --- a/pkg/networking/mocks/protocol.go +++ b/pkg/networking/mocks/protocol.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.46.3. DO NOT EDIT. +// Code generated by mockery v2.50.1. DO NOT EDIT. package networking @@ -20,7 +20,7 @@ func (_m *MockProtocol) EXPECT() *MockProtocol_Expecter { return &MockProtocol_Expecter{mock: &_m.Mock} } -// EmptyHandshake provides a mock function with given fields: +// EmptyHandshake provides a mock function with no fields func (_m *MockProtocol) EmptyHandshake() networking.Handshake { ret := _m.Called() @@ -67,7 +67,7 @@ func (_c *MockProtocol_EmptyHandshake_Call) RunAndReturn(run func() networking.H return _c } -// EmptyHeader provides a mock function with given fields: +// EmptyHeader provides a mock function with no fields func (_m *MockProtocol) EmptyHeader() networking.Header { ret := _m.Called() @@ -206,7 +206,7 @@ func (_c *MockProtocol_IsAcceptableMessage_Call) RunAndReturn(run func(networkin return _c } -// Ping provides a mock function with given fields: +// Ping provides a mock function with no fields func (_m *MockProtocol) Ping() ([]byte, error) { ret := _m.Called() diff --git a/pkg/networking/session.go b/pkg/networking/session.go index 2b7f6e72a..6cfc79e73 100644 --- a/pkg/networking/session.go +++ b/pkg/networking/session.go @@ -162,9 +162,8 @@ func (s *Session) waitForSend(data []byte) error { if s.logger.Enabled(s.ctx, slog.LevelDebug) { s.logger.Debug("Sending data", "data", base64.StdEncoding.EncodeToString(data)) } - ready := &sendPacket{data: data, err: errCh} select { - case s.sendCh <- ready: + case s.sendCh <- newSendPacket(data, errCh): s.logger.Debug("Data written into send channel") case <-s.ctx.Done(): s.logger.Debug("Session shutdown while sending data") @@ -174,24 +173,6 @@ func (s *Session) waitForSend(data []byte) error { return ErrConnectionWriteTimeout } - dataCopy := func() { - if len(data) == 0 { - return // An empty data is ignored. - } - - // In the event of session shutdown or connection write timeout, we need to prevent `send` from reading - // the body buffer after returning from this function since the caller may re-use the underlying array. - ready.mu.Lock() - defer ready.mu.Unlock() - - if ready.data == nil { - return // data was already copied in `send`. - } - newData := make([]byte, len(data)) - copy(newData, data) - ready.data = newData - } - select { case err, ok := <-errCh: if !ok { @@ -201,11 +182,9 @@ func (s *Session) waitForSend(data []byte) error { s.logger.Debug("Error sending data", "error", err) return err case <-s.ctx.Done(): - dataCopy() s.logger.Debug("Session shutdown while waiting send error") return ErrSessionShutdown case <-timer.C: - dataCopy() s.logger.Debug("Connection write timeout while waiting send error") return ErrConnectionWriteTimeout } @@ -224,22 +203,16 @@ func (s *Session) sendLoop() error { case packet := <-s.sendCh: packet.mu.Lock() + _, rErr := dataBuf.ReadFrom(packet.r) + if rErr != nil { + packet.mu.Unlock() + s.logger.Error("Failed to copy data into buffer", "error", rErr) + s.asyncSendErr(packet.err, rErr) + return rErr + } if s.logger.Enabled(s.ctx, slog.LevelDebug) { s.logger.Debug("Sending data to connection", - "data", base64.StdEncoding.EncodeToString(packet.data)) - } - if len(packet.data) != 0 { - // Copy the data into the buffer to avoid holding a mutex lock during the writing. - _, err := dataBuf.Write(packet.data) - if err != nil { - packet.data = nil - packet.mu.Unlock() - s.logger.Error("Failed to copy data into buffer", "error", err) - s.asyncSendErr(packet.err, err) - return err // TODO: Do we need to return here? - } - s.logger.Debug("Data copied into buffer") - packet.data = nil + "data", base64.StdEncoding.EncodeToString(dataBuf.Bytes())) } packet.mu.Unlock() @@ -375,7 +348,7 @@ func (s *Session) readMessagePayload(hdr Header, conn io.Reader) error { s.logger.Debug("Invoking OnReceive handler", "message", base64.StdEncoding.EncodeToString(s.receiveBuffer.Bytes())) } - s.config.handler.OnReceive(s, s.receiveBuffer.Bytes()) // Invoke OnReceive handler. + s.config.handler.OnReceive(s, bytes.NewReader(s.receiveBuffer.Bytes())) // Invoke OnReceive handler. return nil } @@ -405,9 +378,13 @@ func (s *Session) keepaliveLoop() error { // sendPacket is used to send data. type sendPacket struct { - mu sync.Mutex // Protects data from unsafe reads. - data []byte - err chan<- error + mu sync.Mutex // Protects data from unsafe reads. + r io.Reader + err chan<- error +} + +func newSendPacket(data []byte, ch chan<- error) *sendPacket { + return &sendPacket{r: bytes.NewReader(data), err: ch} } // asyncSendErr is used to try an async send of an error. diff --git a/pkg/networking/session_test.go b/pkg/networking/session_test.go index c39220027..745ee194b 100644 --- a/pkg/networking/session_test.go +++ b/pkg/networking/session_test.go @@ -1,6 +1,7 @@ package networking_test import ( + "bytes" "context" "encoding/binary" "errors" @@ -55,7 +56,8 @@ func TestSuccessfulSession(t *testing.T) { require.NoError(t, wErr) assert.Equal(t, 5, n) }) - sc2 := serverHandler.On("OnReceive", ss, encodeMessage("Hello session")).Once().Return() + sc2 := serverHandler.On("OnReceive", ss, bytes.NewReader(encodeMessage("Hello session"))). + Once().Return() sc2.NotBefore(sc1). Run(func(_ mock.Arguments) { n, wErr := ss.Write(encodeMessage("Hi")) @@ -73,7 +75,7 @@ func TestSuccessfulSession(t *testing.T) { require.NoError(t, wErr) assert.Equal(t, 17, n) }) - cl2 := clientHandler.On("OnReceive", cs, encodeMessage("Hi")).Once().Return() + cl2 := clientHandler.On("OnReceive", cs, bytes.NewReader(encodeMessage("Hi"))).Once().Return() cl2.NotBefore(cl1). Run(func(_ mock.Arguments) { cWG.Done() From 78579ca4c48ece31fdbffa4c360b5fc919332c46 Mon Sep 17 00:00:00 2001 From: Alexey Kiselev Date: Fri, 27 Dec 2024 19:54:49 +0400 Subject: [PATCH 29/30] Itest network client handler updated. --- itests/clients/net_client.go | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/itests/clients/net_client.go b/itests/clients/net_client.go index 821999517..e6d984adb 100644 --- a/itests/clients/net_client.go +++ b/itests/clients/net_client.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/base64" + "io" "log/slog" "net" "sync" @@ -179,7 +180,13 @@ func newHandler(t testing.TB, peers []proto.PeerInfo) *handler { return &handler{t: t, peers: peers} } -func (h *handler) OnReceive(s *networking.Session, data []byte) { +func (h *handler) OnReceive(s *networking.Session, r io.Reader) { + data, err := io.ReadAll(r) + if err != nil { + h.t.Logf("Failed to read message from %q: %v", s.RemoteAddr(), err) + h.t.FailNow() + return + } msg, err := proto.UnmarshalMessage(data) if err != nil { // Fail test on unmarshal error. h.t.Logf("Failed to unmarshal message from bytes: %q", base64.StdEncoding.EncodeToString(data)) From de29ff8ccacc2328cce3182ad1dbfe3a122f0888 Mon Sep 17 00:00:00 2001 From: Alexey Kiselev Date: Sat, 28 Dec 2024 14:23:25 +0400 Subject: [PATCH 30/30] Changed the way OnReceive passes the receiveBuffer. Test updated. --- pkg/networking/session.go | 2 +- pkg/networking/session_test.go | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pkg/networking/session.go b/pkg/networking/session.go index 6cfc79e73..bdab4bdd4 100644 --- a/pkg/networking/session.go +++ b/pkg/networking/session.go @@ -348,7 +348,7 @@ func (s *Session) readMessagePayload(hdr Header, conn io.Reader) error { s.logger.Debug("Invoking OnReceive handler", "message", base64.StdEncoding.EncodeToString(s.receiveBuffer.Bytes())) } - s.config.handler.OnReceive(s, bytes.NewReader(s.receiveBuffer.Bytes())) // Invoke OnReceive handler. + s.config.handler.OnReceive(s, s.receiveBuffer) // Invoke OnReceive handler. return nil } diff --git a/pkg/networking/session_test.go b/pkg/networking/session_test.go index 745ee194b..c30bfac81 100644 --- a/pkg/networking/session_test.go +++ b/pkg/networking/session_test.go @@ -56,7 +56,7 @@ func TestSuccessfulSession(t *testing.T) { require.NoError(t, wErr) assert.Equal(t, 5, n) }) - sc2 := serverHandler.On("OnReceive", ss, bytes.NewReader(encodeMessage("Hello session"))). + sc2 := serverHandler.On("OnReceive", ss, bytes.NewBuffer(encodeMessage("Hello session"))). Once().Return() sc2.NotBefore(sc1). Run(func(_ mock.Arguments) { @@ -75,7 +75,7 @@ func TestSuccessfulSession(t *testing.T) { require.NoError(t, wErr) assert.Equal(t, 17, n) }) - cl2 := clientHandler.On("OnReceive", cs, bytes.NewReader(encodeMessage("Hi"))).Once().Return() + cl2 := clientHandler.On("OnReceive", cs, bytes.NewBuffer(encodeMessage("Hi"))).Once().Return() cl2.NotBefore(cl1). Run(func(_ mock.Arguments) { cWG.Done()