Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

producer: Integrate context.Context with publishing #365

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 31 additions & 12 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bufio"
"bytes"
"compress/flate"
"context"
"crypto/tls"
"encoding/json"
"errors"
Expand Down Expand Up @@ -119,8 +120,7 @@ func NewConn(addr string, config *Config, delegate ConnDelegate) *Conn {
// The logger parameter is an interface that requires the following
// method to be implemented (such as the the stdlib log.Logger):
//
// Output(calldepth int, s string)
//
// Output(calldepth int, s string)
func (c *Conn) SetLogger(l logger, lvl LogLevel, format string) {
c.logGuard.Lock()
defer c.logGuard.Unlock()
Expand Down Expand Up @@ -171,12 +171,18 @@ func (c *Conn) getLogLevel() LogLevel {
// Connect dials and bootstraps the nsqd connection
// (including IDENTIFY) and returns the IdentifyResponse
func (c *Conn) Connect() (*IdentifyResponse, error) {
ctx := context.Background()
return c.ConnectWithContext(ctx)
}

func (c *Conn) ConnectWithContext(ctx context.Context) (*IdentifyResponse, error) {
dialer := &net.Dialer{
LocalAddr: c.config.LocalAddr,
Timeout: c.config.DialTimeout,
}

conn, err := dialer.Dial("tcp", c.addr)
// the timeout used is smallest of dialer.Timeout (config.DialTimeout) or context timeout
conn, err := dialer.DialContext(ctx, "tcp", c.addr)
if err != nil {
return nil, err
}
Expand All @@ -190,7 +196,7 @@ func (c *Conn) Connect() (*IdentifyResponse, error) {
return nil, fmt.Errorf("[%s] failed to write magic - %s", c.addr, err)
}

resp, err := c.identify()
resp, err := c.identify(ctx)
if err != nil {
return nil, err
}
Expand All @@ -200,7 +206,7 @@ func (c *Conn) Connect() (*IdentifyResponse, error) {
c.log(LogLevelError, "Auth Required")
return nil, errors.New("Auth Required")
}
err := c.auth(c.config.AuthSecret)
err := c.auth(ctx, c.config.AuthSecret)
if err != nil {
c.log(LogLevelError, "Auth Failed %s", err)
return nil, err
Expand Down Expand Up @@ -291,13 +297,26 @@ func (c *Conn) Write(p []byte) (int, error) {
// WriteCommand is a goroutine safe method to write a Command
// to this connection, and flush.
func (c *Conn) WriteCommand(cmd *Command) error {
ctx := context.Background()
return c.WriteCommandWithContext(ctx, cmd)
}

func (c *Conn) WriteCommandWithContext(ctx context.Context, cmd *Command) error {
c.mtx.Lock()

_, err := cmd.WriteTo(c)
if err != nil {
var err error
select {
case <-ctx.Done():
c.mtx.Unlock()
return ctx.Err()
default:
_, err = cmd.WriteTo(c)
if err != nil {
goto exit
}
err = c.Flush()
goto exit
}
err = c.Flush()

exit:
c.mtx.Unlock()
Expand All @@ -320,7 +339,7 @@ func (c *Conn) Flush() error {
return nil
}

func (c *Conn) identify() (*IdentifyResponse, error) {
func (c *Conn) identify(ctx context.Context) (*IdentifyResponse, error) {
ci := make(map[string]interface{})
ci["client_id"] = c.config.ClientID
ci["hostname"] = c.config.Hostname
Expand Down Expand Up @@ -350,7 +369,7 @@ func (c *Conn) identify() (*IdentifyResponse, error) {
return nil, ErrIdentify{err.Error()}
}

err = c.WriteCommand(cmd)
err = c.WriteCommandWithContext(ctx, cmd)
if err != nil {
return nil, ErrIdentify{err.Error()}
}
Expand Down Expand Up @@ -479,13 +498,13 @@ func (c *Conn) upgradeSnappy() error {
return nil
}

func (c *Conn) auth(secret string) error {
func (c *Conn) auth(ctx context.Context, secret string) error {
cmd, err := Auth(secret)
if err != nil {
return err
}

err = c.WriteCommand(cmd)
err = c.WriteCommandWithContext(ctx, cmd)
if err != nil {
return err
}
Expand Down
97 changes: 79 additions & 18 deletions producer.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package nsq

import (
"context"
"fmt"
"log"
"os"
Expand All @@ -15,8 +16,10 @@ type producerConn interface {
SetLoggerLevel(LogLevel)
SetLoggerForLevel(logger, LogLevel, string)
Connect() (*IdentifyResponse, error)
ConnectWithContext(context.Context) (*IdentifyResponse, error)
Close() error
WriteCommand(*Command) error
WriteCommandWithContext(context.Context, *Command) error
}

// Producer is a high-level type to publish to NSQ.
Expand Down Expand Up @@ -53,6 +56,7 @@ type Producer struct {
// to retrieve metadata about the command after the
// response is received.
type ProducerTransaction struct {
ctx context.Context
cmd *Command
doneChan chan *ProducerTransaction
Error error // the error (or nil) of the publish command
Expand Down Expand Up @@ -105,23 +109,27 @@ func NewProducer(addr string, config *Config) (*Producer, error) {
// configured correctly, rather than relying on the lazy "connect on Publish"
// behavior of a Producer.
func (w *Producer) Ping() error {
ctx := context.Background()
return w.PingWithContext(ctx)
}

func (w *Producer) PingWithContext(ctx context.Context) error {
if atomic.LoadInt32(&w.state) != StateConnected {
err := w.connect()
err := w.connect(ctx)
if err != nil {
return err
}
}

return w.conn.WriteCommand(Nop())
return w.conn.WriteCommandWithContext(ctx, Nop())
}

// SetLogger assigns the logger to use as well as a level
//
// The logger parameter is an interface that requires the following
// method to be implemented (such as the the stdlib log.Logger):
//
// Output(calldepth int, s string)
//
// Output(calldepth int, s string)
func (w *Producer) SetLogger(l logger, lvl LogLevel) {
w.logGuard.Lock()
defer w.logGuard.Unlock()
Expand Down Expand Up @@ -192,7 +200,13 @@ func (w *Producer) Stop() {
// and the response error if present
func (w *Producer) PublishAsync(topic string, body []byte, doneChan chan *ProducerTransaction,
args ...interface{}) error {
return w.sendCommandAsync(Publish(topic, body), doneChan, args)
ctx := context.Background()
return w.PublishAsyncWithContext(ctx, topic, body, doneChan, args...)
}

func (w *Producer) PublishAsyncWithContext(ctx context.Context, topic string, body []byte, doneChan chan *ProducerTransaction,
args ...interface{}) error {
return w.sendCommandAsync(ctx, Publish(topic, body), doneChan, args)
}

// MultiPublishAsync publishes a slice of message bodies to the specified topic
Expand All @@ -203,35 +217,56 @@ func (w *Producer) PublishAsync(topic string, body []byte, doneChan chan *Produc
// will receive a `ProducerTransaction` instance with the supplied variadic arguments
// and the response error if present
func (w *Producer) MultiPublishAsync(topic string, body [][]byte, doneChan chan *ProducerTransaction,
args ...interface{}) error {
ctx := context.Background()
return w.MultiPublishAsyncWithContext(ctx, topic, body, doneChan, args...)
}

func (w *Producer) MultiPublishAsyncWithContext(ctx context.Context, topic string, body [][]byte, doneChan chan *ProducerTransaction,
args ...interface{}) error {
cmd, err := MultiPublish(topic, body)
if err != nil {
return err
}
return w.sendCommandAsync(cmd, doneChan, args)
return w.sendCommandAsync(ctx, cmd, doneChan, args)
}

// Publish synchronously publishes a message body to the specified topic, returning
// an error if publish failed
func (w *Producer) Publish(topic string, body []byte) error {
return w.sendCommand(Publish(topic, body))
ctx := context.Background()
return w.PublishWithContext(ctx, topic, body)
}

func (w *Producer) PublishWithContext(ctx context.Context, topic string, body []byte) error {
return w.sendCommand(ctx, Publish(topic, body))
}

// MultiPublish synchronously publishes a slice of message bodies to the specified topic, returning
// an error if publish failed
func (w *Producer) MultiPublish(topic string, body [][]byte) error {
ctx := context.Background()
return w.MultiPublishWithContext(ctx, topic, body)
}

func (w *Producer) MultiPublishWithContext(ctx context.Context, topic string, body [][]byte) error {
cmd, err := MultiPublish(topic, body)
if err != nil {
return err
}
return w.sendCommand(cmd)
return w.sendCommand(ctx, cmd)
}

// DeferredPublish synchronously publishes a message body to the specified topic
// where the message will queue at the channel level until the timeout expires, returning
// an error if publish failed
func (w *Producer) DeferredPublish(topic string, delay time.Duration, body []byte) error {
return w.sendCommand(DeferredPublish(topic, delay, body))
ctx := context.Background()
return w.DeferredPublishWithContext(ctx, topic, delay, body)
}

func (w *Producer) DeferredPublishWithContext(ctx context.Context, topic string, delay time.Duration, body []byte) error {
return w.sendCommand(ctx, DeferredPublish(topic, delay, body))
}

// DeferredPublishAsync publishes a message body to the specified topic
Expand All @@ -244,12 +279,18 @@ func (w *Producer) DeferredPublish(topic string, delay time.Duration, body []byt
// and the response error if present
func (w *Producer) DeferredPublishAsync(topic string, delay time.Duration, body []byte,
doneChan chan *ProducerTransaction, args ...interface{}) error {
return w.sendCommandAsync(DeferredPublish(topic, delay, body), doneChan, args)
ctx := context.Background()
return w.DeferredPublishAsyncWithContext(ctx, topic, delay, body, doneChan, args...)
}

func (w *Producer) sendCommand(cmd *Command) error {
func (w *Producer) DeferredPublishAsyncWithContext(ctx context.Context, topic string, delay time.Duration, body []byte,
doneChan chan *ProducerTransaction, args ...interface{}) error {
return w.sendCommandAsync(ctx, DeferredPublish(topic, delay, body), doneChan, args)
}

func (w *Producer) sendCommand(ctx context.Context, cmd *Command) error {
doneChan := make(chan *ProducerTransaction)
err := w.sendCommandAsync(cmd, doneChan, nil)
err := w.sendCommandAsync(ctx, cmd, doneChan, nil)
if err != nil {
close(doneChan)
return err
Expand All @@ -258,21 +299,22 @@ func (w *Producer) sendCommand(cmd *Command) error {
return t.Error
}

func (w *Producer) sendCommandAsync(cmd *Command, doneChan chan *ProducerTransaction,
func (w *Producer) sendCommandAsync(ctx context.Context, cmd *Command, doneChan chan *ProducerTransaction,
args []interface{}) error {
// keep track of how many outstanding producers we're dealing with
// in order to later ensure that we clean them all up...
atomic.AddInt32(&w.concurrentProducers, 1)
defer atomic.AddInt32(&w.concurrentProducers, -1)

if atomic.LoadInt32(&w.state) != StateConnected {
err := w.connect()
err := w.connect(ctx)
if err != nil {
return err
}
}

t := &ProducerTransaction{
ctx: ctx,
cmd: cmd,
doneChan: doneChan,
Args: args,
Expand All @@ -282,12 +324,14 @@ func (w *Producer) sendCommandAsync(cmd *Command, doneChan chan *ProducerTransac
case w.transactionChan <- t:
case <-w.exitChan:
return ErrStopped
case <-ctx.Done():
return ctx.Err()
}

return nil
}

func (w *Producer) connect() error {
func (w *Producer) connect(ctx context.Context) error {
w.guard.Lock()
defer w.guard.Unlock()

Expand All @@ -312,7 +356,7 @@ func (w *Producer) connect() error {
w.conn.SetLoggerForLevel(w.logger[index], LogLevel(index), format)
}

_, err := w.conn.Connect()
_, err := w.conn.ConnectWithContext(ctx)
if err != nil {
w.conn.Close()
w.log(LogLevelError, "(%s) error connecting to nsqd - %s", w.addr, err)
Expand Down Expand Up @@ -344,9 +388,19 @@ func (w *Producer) router() {
select {
case t := <-w.transactionChan:
w.transactions = append(w.transactions, t)
err := w.conn.WriteCommand(t.cmd)
err := w.conn.WriteCommandWithContext(t.ctx, t.cmd)
if err != nil {
w.log(LogLevelError, "(%s) sending command - %s", w.conn.String(), err)

switch err {
case context.Canceled:
w.popTransaction(FrameTypeContextCanceled, []byte(err.Error()))
continue
case context.DeadlineExceeded:
w.popTransaction(FrameTypeContextDeadlineExceeded, []byte(err.Error()))
continue
}

w.close()
}
case data := <-w.responseChan:
Expand Down Expand Up @@ -380,9 +434,16 @@ func (w *Producer) popTransaction(frameType int32, data []byte) {
}
t := w.transactions[0]
w.transactions = w.transactions[1:]
if frameType == FrameTypeError {

switch frameType {
case FrameTypeError:
t.Error = ErrProtocol{string(data)}
case FrameTypeContextCanceled:
t.Error = context.Canceled
case FrameTypeContextDeadlineExceeded:
t.Error = context.DeadlineExceeded
}

t.finish()
}

Expand Down
Loading