Skip to content

Commit

Permalink
utilize popTransaction for context error cases
Browse files Browse the repository at this point in the history
  • Loading branch information
Ulminator committed Jan 24, 2025
1 parent d85bbc7 commit 7ba56c7
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 31 deletions.
23 changes: 16 additions & 7 deletions producer.go
Original file line number Diff line number Diff line change
Expand Up @@ -391,14 +391,16 @@ func (w *Producer) router() {
err := w.conn.WriteCommandWithContext(t.ctx, t.cmd)
if err != nil {
w.log(LogLevelError, "(%s) sending command - %s", w.conn.String(), err)
if err == context.Canceled || err == context.DeadlineExceeded {
// keep the connection alive if related to context timeout
// need to do some stuff that's in Producer.popTransaction here
w.transactions = w.transactions[1:]
t.Error = err
t.finish()

switch err {
case context.Canceled:
w.popTransaction(FrameTypeContextCanceled, []byte(err.Error()))
continue
case context.DeadlineExceeded:
w.popTransaction(FrameTypeContextDeadlineExceeded, []byte(err.Error()))
continue
}

w.close()
}
case data := <-w.responseChan:
Expand Down Expand Up @@ -432,9 +434,16 @@ func (w *Producer) popTransaction(frameType int32, data []byte) {
}
t := w.transactions[0]
w.transactions = w.transactions[1:]
if frameType == FrameTypeError {

switch frameType {
case FrameTypeError:
t.Error = ErrProtocol{string(data)}
case FrameTypeContextCanceled:
t.Error = context.Canceled
case FrameTypeContextDeadlineExceeded:
t.Error = context.DeadlineExceeded
}

t.finish()
}

Expand Down
48 changes: 37 additions & 11 deletions producer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,8 @@ func TestProducerPublish(t *testing.T) {
func TestProducerPublishWithContext(t *testing.T) {
topicName := "publish" + strconv.Itoa(int(time.Now().Unix()))
publishAttempts := 100
publishFailures := 0
ctxCanceledCount := 0
ctxDeadlineExceededCount := 0

config := NewConfig()
w, _ := NewProducer("127.0.0.1:4150", config)
Expand All @@ -122,30 +123,55 @@ func TestProducerPublishWithContext(t *testing.T) {
// with the low timeout, the DialContext will fail and close conn, so Ping w/ out context first
w.Ping()

timeout := time.Duration(5) * time.Microsecond
timeout := 5 * time.Microsecond

var wg sync.WaitGroup
var mu sync.Mutex
var errs []error
for i := 0; i < publishAttempts; i++ {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
wg.Add(1)
go func() {
defer wg.Done()
err := w.PublishWithContext(ctx, topicName, []byte("publish_test_case"))
mu.Lock()
errs = append(errs, err)
mu.Unlock()
}()

err := w.PublishWithContext(ctx, topicName, []byte("publish_test_case"))
if err != nil {
if err != context.DeadlineExceeded {
t.Fatalf("error %s", err)
}
publishFailures++
// this sleep enables seeing both context.Canceled and context.DeadlineExceeded errors
time.Sleep(timeout)
cancel()
}
wg.Wait()

for _, err := range errs {
switch err {
case nil:
case context.Canceled:
ctxCanceledCount++
case context.DeadlineExceeded:
ctxDeadlineExceededCount++
default:
t.Fatalf("error %s", err)
}
}

if ctxCanceledCount == 0 || ctxDeadlineExceededCount == 0 {
t.Fatalf("expected both context.Canceled and context.DeadlineExceeded errors, got %d and %d respectively", ctxCanceledCount, ctxDeadlineExceededCount)
}

err := w.Publish(topicName, []byte("bad_test_case"))
if err != nil {
t.Fatalf("error %s", err)
}

publishFailures := ctxCanceledCount + ctxDeadlineExceededCount
publishSuccesses := publishAttempts - publishFailures
if publishSuccesses == 0 || publishFailures == 0 {
t.Fatalf("expected both successful and failed publishes, got %d and %d", publishSuccesses, publishFailures)
t.Fatalf("expected both successful and failed publishes, got %d and %d respectively", publishSuccesses, publishFailures)
}
// ensure that if a context.DeadlineExceeded error is returned, no message is actually published
// ensure that if a context error is returned, no message is actually published
readMessages(topicName, t, publishSuccesses)
}

Expand Down
28 changes: 15 additions & 13 deletions protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@ var MagicV2 = []byte(" V2")

// frame types
const (
FrameTypeResponse int32 = 0
FrameTypeError int32 = 1
FrameTypeMessage int32 = 2
FrameTypeResponse int32 = 0
FrameTypeError int32 = 1
FrameTypeMessage int32 = 2
FrameTypeContextCanceled int32 = 3
FrameTypeContextDeadlineExceeded int32 = 4
)

// Used to detect if an unexpected HTTP response is read
Expand Down Expand Up @@ -46,11 +48,11 @@ func isValidName(name string) bool {
// ReadResponse is a client-side utility function to read from the supplied Reader
// according to the NSQ protocol spec:
//
// [x][x][x][x][x][x][x][x]...
// | (int32) || (binary)
// | 4-byte || N-byte
// ------------------------...
// size data
// [x][x][x][x][x][x][x][x]...
// | (int32) || (binary)
// | 4-byte || N-byte
// ------------------------...
// size data
func ReadResponse(r io.Reader, maxMsgSize int32) ([]byte, error) {
var msgSize int32

Expand Down Expand Up @@ -84,11 +86,11 @@ func ReadResponse(r io.Reader, maxMsgSize int32) ([]byte, error) {
// UnpackResponse is a client-side utility function that unpacks serialized data
// according to NSQ protocol spec:
//
// [x][x][x][x][x][x][x][x]...
// | (int32) || (binary)
// | 4-byte || N-byte
// ------------------------...
// frame ID data
// [x][x][x][x][x][x][x][x]...
// | (int32) || (binary)
// | 4-byte || N-byte
// ------------------------...
// frame ID data
//
// Returns a triplicate of: frame type, data ([]byte), error
func UnpackResponse(response []byte) (int32, []byte, error) {
Expand Down

0 comments on commit 7ba56c7

Please sign in to comment.