diff --git a/batch_consumer_test.go b/batch_consumer_test.go index 5c5d950..0679d8a 100644 --- a/batch_consumer_test.go +++ b/batch_consumer_test.go @@ -2,6 +2,8 @@ package kafka import ( "errors" + "reflect" + "strconv" "sync" "testing" "time" @@ -207,6 +209,72 @@ func Test_batchConsumer_process(t *testing.T) { }) } +func Test_batchConsumer_chunk(t *testing.T) { + tests := []struct { + allMessages []*Message + chunkSize int + expected [][]*Message + }{ + { + allMessages: createMessages(0, 9), + chunkSize: 3, + expected: [][]*Message{ + createMessages(0, 3), + createMessages(3, 6), + createMessages(6, 9), + }, + }, + { + allMessages: []*Message{}, + chunkSize: 3, + expected: [][]*Message{}, + }, + { + allMessages: createMessages(0, 1), + chunkSize: 3, + expected: [][]*Message{ + createMessages(0, 1), + }, + }, + { + allMessages: createMessages(0, 8), + chunkSize: 3, + expected: [][]*Message{ + createMessages(0, 3), + createMessages(3, 6), + createMessages(6, 8), + }, + }, + { + allMessages: createMessages(0, 3), + chunkSize: 3, + expected: [][]*Message{ + createMessages(0, 3), + }, + }, + } + + for i, tc := range tests { + t.Run(strconv.Itoa(i), func(t *testing.T) { + chunkedMessages := chunkMessages(tc.allMessages, tc.chunkSize) + + if !reflect.DeepEqual(chunkedMessages, tc.expected) && !(len(chunkedMessages) == 0 && len(tc.expected) == 0) { + t.Errorf("For chunkSize %d, expected %v, but got %v", tc.chunkSize, tc.expected, chunkedMessages) + } + }) + } +} + +func createMessages(partitionStart int, partitionEnd int) []*Message { + messages := make([]*Message, 0) + for i := partitionStart; i < partitionEnd; i++ { + messages = append(messages, &Message{ + Partition: i, + }) + } + return messages +} + type mockCronsumer struct { wantErr bool }