diff --git a/internal/bpool/bpool.go b/internal/bpool/bpool.go new file mode 100644 index 00000000..4266c236 --- /dev/null +++ b/internal/bpool/bpool.go @@ -0,0 +1,24 @@ +package bpool + +import ( + "bytes" + "sync" +) + +var bpool sync.Pool + +// Get returns a buffer from the pool or creates a new one if +// the pool is empty. +func Get() *bytes.Buffer { + b, ok := bpool.Get().(*bytes.Buffer) + if !ok { + b = &bytes.Buffer{} + } + return b +} + +// Put returns a buffer into the pool. +func Put(b *bytes.Buffer) { + b.Reset() + bpool.Put(b) +} diff --git a/internal/bpool/bpool_test.go b/internal/bpool/bpool_test.go new file mode 100644 index 00000000..2b302a47 --- /dev/null +++ b/internal/bpool/bpool_test.go @@ -0,0 +1,47 @@ +package bpool + +import ( + "strconv" + "sync" + "testing" +) + +func BenchmarkSyncPool(b *testing.B) { + sizes := []int{ + 2, + 16, + 32, + 64, + 128, + 256, + 512, + 4096, + 16384, + } + for _, size := range sizes { + b.Run(strconv.Itoa(size), func(b *testing.B) { + b.Run("allocate", func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + buf := make([]byte, size) + _ = buf + } + }) + b.Run("pool", func(b *testing.B) { + b.ReportAllocs() + + p := sync.Pool{} + + b.ResetTimer() + for i := 0; i < b.N; i++ { + buf := p.Get() + if buf == nil { + buf = make([]byte, size) + } + + p.Put(buf) + } + }) + }) + } +} diff --git a/wsjson/wsjson.go b/wsjson/wsjson.go index 994ffad1..fdde2e06 100644 --- a/wsjson/wsjson.go +++ b/wsjson/wsjson.go @@ -8,6 +8,7 @@ import ( "golang.org/x/xerrors" "nhooyr.io/websocket" + "nhooyr.io/websocket/internal/bpool" ) // Read reads a json message from c into v. @@ -22,7 +23,7 @@ func Read(ctx context.Context, c *websocket.Conn, v interface{}) error { } func read(ctx context.Context, c *websocket.Conn, v interface{}) error { - typ, b, err := c.Read(ctx) + typ, r, err := c.Reader(ctx) if err != nil { return err } @@ -32,7 +33,17 @@ func read(ctx context.Context, c *websocket.Conn, v interface{}) error { return xerrors.Errorf("unexpected frame type for json (expected %v): %v", websocket.MessageText, typ) } - err = json.Unmarshal(b, v) + b := bpool.Get() + defer func() { + bpool.Put(b) + }() + + _, err = b.ReadFrom(r) + if err != nil { + return err + } + + err = json.Unmarshal(b.Bytes(), v) if err != nil { return xerrors.Errorf("failed to unmarshal json: %w", err) } diff --git a/wspb/wspb.go b/wspb/wspb.go index e6c91693..49c2ae54 100644 --- a/wspb/wspb.go +++ b/wspb/wspb.go @@ -2,12 +2,15 @@ package wspb import ( + "bytes" "context" + "sync" "github.com/golang/protobuf/proto" "golang.org/x/xerrors" "nhooyr.io/websocket" + "nhooyr.io/websocket/internal/bpool" ) // Read reads a protobuf message from c into v. @@ -21,7 +24,7 @@ func Read(ctx context.Context, c *websocket.Conn, v proto.Message) error { } func read(ctx context.Context, c *websocket.Conn, v proto.Message) error { - typ, b, err := c.Read(ctx) + typ, r, err := c.Reader(ctx) if err != nil { return err } @@ -31,7 +34,17 @@ func read(ctx context.Context, c *websocket.Conn, v proto.Message) error { return xerrors.Errorf("unexpected frame type for protobuf (expected %v): %v", websocket.MessageBinary, typ) } - err = proto.Unmarshal(b, v) + b := bpool.Get() + defer func() { + bpool.Put(b) + }() + + _, err = b.ReadFrom(r) + if err != nil { + return err + } + + err = proto.Unmarshal(b.Bytes(), v) if err != nil { return xerrors.Errorf("failed to unmarshal protobuf: %w", err) } @@ -49,11 +62,19 @@ func Write(ctx context.Context, c *websocket.Conn, v proto.Message) error { return nil } +var writeBufPool sync.Pool + func write(ctx context.Context, c *websocket.Conn, v proto.Message) error { - b, err := proto.Marshal(v) + b := bpool.Get() + pb := proto.NewBuffer(b.Bytes()) + defer func() { + bpool.Put(bytes.NewBuffer(pb.Bytes())) + }() + + err := pb.Marshal(v) if err != nil { return xerrors.Errorf("failed to marshal protobuf: %w", err) } - return c.Write(ctx, websocket.MessageBinary, b) + return c.Write(ctx, websocket.MessageBinary, pb.Bytes()) }