diff --git a/Makefile b/Makefile index fecda8c..af2f0d9 100644 --- a/Makefile +++ b/Makefile @@ -39,3 +39,6 @@ integration-test: ## run-act: act for running github actions on your local machine run-act: act -j test --container-architecture linux/arm64 + +run-benchmarks: + go test -run none -bench . -benchtime=5s \ No newline at end of file diff --git a/batch_consumer.go b/batch_consumer.go index bffcb66..2dfde32 100644 --- a/batch_consumer.go +++ b/batch_consumer.go @@ -2,6 +2,7 @@ package kafka import ( "errors" + "fmt" "time" "github.com/prometheus/client_golang/prometheus" @@ -143,33 +144,36 @@ func (b *batchConsumer) setupConcurrentWorkers() { } } -func chunkMessages(allMessages *[]*Message, chunkSize int, chunkByteSize int) [][]*Message { +func chunkMessagesOptimized(allMessages []*Message, chunkSize int, chunkByteSize int) [][]*Message { + if chunkSize <= 0 { + panic("chunkSize must be greater than 0") + } + var chunks [][]*Message + totalMessages := len(allMessages) + estimatedChunks := (totalMessages + chunkSize - 1) / chunkSize + chunks = make([][]*Message, 0, estimatedChunks) - allMessageList := *allMessages var currentChunk []*Message - currentChunkSize := 0 + currentChunk = make([]*Message, 0, chunkSize) currentChunkBytes := 0 - for _, message := range allMessageList { + for _, message := range allMessages { messageByteSize := len(message.Value) // Check if adding this message would exceed either the chunk size or the byte size - if len(currentChunk) >= chunkSize || (chunkByteSize != 0 && currentChunkBytes+messageByteSize > chunkByteSize) { - // Avoid too low chunkByteSize + if len(currentChunk) >= chunkSize || (chunkByteSize > 0 && currentChunkBytes+messageByteSize > chunkByteSize) { if len(currentChunk) == 0 { - panic("invalid chunk byte size, please increase it") + panic(fmt.Sprintf("invalid chunk byte size (messageGroupByteSizeLimit) %d, "+ + "message byte size is %d, bigger!, increase chunk byte size limit", chunkByteSize, messageByteSize)) } - // If it does, finalize the current chunk and start a new one chunks = append(chunks, currentChunk) - currentChunk = []*Message{} - currentChunkSize = 0 + currentChunk = make([]*Message, 0, chunkSize) currentChunkBytes = 0 } // Add the message to the current chunk currentChunk = append(currentChunk, message) - currentChunkSize++ currentChunkBytes += messageByteSize } @@ -182,11 +186,11 @@ func chunkMessages(allMessages *[]*Message, chunkSize int, chunkByteSize int) [] } func (b *batchConsumer) consume(allMessages *[]*Message, commitMessages *[]kafka.Message, messageByteSizeLimit *int) { - chunks := chunkMessages(allMessages, b.messageGroupLimit, b.messageGroupByteSizeLimit) + chunks := chunkMessagesOptimized(*allMessages, b.messageGroupLimit, b.messageGroupByteSizeLimit) if b.preBatchFn != nil { preBatchResult := b.preBatchFn(*allMessages) - chunks = chunkMessages(&preBatchResult, b.messageGroupLimit, b.messageGroupByteSizeLimit) + chunks = chunkMessagesOptimized(preBatchResult, b.messageGroupLimit, b.messageGroupByteSizeLimit) } // Send the messages to process diff --git a/batch_consumer_test.go b/batch_consumer_test.go index 0981922..b49d546 100644 --- a/batch_consumer_test.go +++ b/batch_consumer_test.go @@ -4,7 +4,6 @@ import ( "context" "errors" "reflect" - "strconv" "sync" "testing" "time" @@ -299,79 +298,6 @@ func Test_batchConsumer_process(t *testing.T) { }) } -func Test_batchConsumer_chunk(t *testing.T) { - tests := []struct { - allMessages []*Message - expected [][]*Message - chunkSize int - chunkByteSize int - }{ - { - allMessages: createMessages(0, 9), - chunkSize: 3, - chunkByteSize: 10000, - expected: [][]*Message{ - createMessages(0, 3), - createMessages(3, 6), - createMessages(6, 9), - }, - }, - { - allMessages: []*Message{}, - chunkSize: 3, - chunkByteSize: 10000, - expected: [][]*Message{}, - }, - { - allMessages: createMessages(0, 1), - chunkSize: 3, - chunkByteSize: 10000, - expected: [][]*Message{ - createMessages(0, 1), - }, - }, - { - allMessages: createMessages(0, 8), - chunkSize: 3, - chunkByteSize: 10000, - expected: [][]*Message{ - createMessages(0, 3), - createMessages(3, 6), - createMessages(6, 8), - }, - }, - { - allMessages: createMessages(0, 3), - chunkSize: 3, - chunkByteSize: 10000, - expected: [][]*Message{ - createMessages(0, 3), - }, - }, - - { - allMessages: createMessages(0, 3), - chunkSize: 100, - chunkByteSize: 4, - expected: [][]*Message{ - createMessages(0, 1), - createMessages(1, 2), - createMessages(2, 3), - }, - }, - } - - for i, tc := range tests { - t.Run(strconv.Itoa(i), func(t *testing.T) { - chunkedMessages := chunkMessages(&tc.allMessages, tc.chunkSize, tc.chunkByteSize) - - if !reflect.DeepEqual(chunkedMessages, tc.expected) && !(len(chunkedMessages) == 0 && len(tc.expected) == 0) { - t.Errorf("For chunkSize %d, expected %v, but got %v", tc.chunkSize, tc.expected, chunkedMessages) - } - }) - } -} - func Test_batchConsumer_Pause(t *testing.T) { // Given ctx, cancelFn := context.WithCancel(context.Background()) @@ -456,6 +382,187 @@ func Test_batchConsumer_runKonsumerFn(t *testing.T) { }) } +func Test_batchConsumer_chunk(t *testing.T) { + type testCase struct { + name string + allMessages []*Message + chunkSize int + chunkByteSize int + expected [][]*Message + shouldPanic bool + } + + tests := []testCase{ + { + name: "Should_Return_3_Chunks_For_9_Messages", + allMessages: createMessages(0, 9), + chunkSize: 3, + chunkByteSize: 10000, + expected: [][]*Message{ + createMessages(0, 3), + createMessages(3, 6), + createMessages(6, 9), + }, + shouldPanic: false, + }, + { + name: "Should_Return_Empty_Slice_When_Input_Is_Empty", + allMessages: []*Message{}, + chunkSize: 3, + chunkByteSize: 10000, + expected: [][]*Message{}, + shouldPanic: false, + }, + { + name: "Should_Return_Single_Chunk_When_Single_Message", + allMessages: createMessages(0, 1), + chunkSize: 3, + chunkByteSize: 10000, + expected: [][]*Message{ + createMessages(0, 1), + }, + shouldPanic: false, + }, + { + name: "Should_Splits_Into_Multiple_Chunks_With_Incomplete_Final_Chunk", + allMessages: createMessages(0, 8), + chunkSize: 3, + chunkByteSize: 10000, + expected: [][]*Message{ + createMessages(0, 3), + createMessages(3, 6), + createMessages(6, 8), + }, + shouldPanic: false, + }, + { + name: "Should_Return_Exact_Chunk_Size_Forms_Single_Chunk", + allMessages: createMessages(0, 3), + chunkSize: 3, + chunkByteSize: 10000, + expected: [][]*Message{ + createMessages(0, 3), + }, + shouldPanic: false, + }, + { + name: "Should_Forces_Single_Message_Per_Chunk_When_Small_chunkByteSize_Is_Given", + allMessages: createMessages(0, 3), + chunkSize: 100, + chunkByteSize: 4, // Each message has Value size 4 + expected: [][]*Message{ + createMessages(0, 1), + createMessages(1, 2), + createMessages(2, 3), + }, + shouldPanic: false, + }, + { + name: "Should_Ignore_Byte_Size_When_chunkByteSize=0", + allMessages: createMessages(0, 5), + chunkSize: 2, + chunkByteSize: 0, + expected: [][]*Message{ + createMessages(0, 2), + createMessages(2, 4), + createMessages(4, 5), + }, + shouldPanic: false, + }, + { + name: "Should_Panic_When_chunkByteSize_Less_Than_Message_Size", + allMessages: createMessages(0, 1), + chunkSize: 2, + chunkByteSize: 3, // Message size is 4 + expected: nil, + shouldPanic: true, + }, + { + name: "Should_Panic_When_chunkSize=0", + allMessages: createMessages(0, 1), + chunkSize: 0, + chunkByteSize: 10000, + expected: nil, + shouldPanic: true, + }, + { + name: "Should_Panic_When_Negative_chunkSize", + allMessages: createMessages(0, 1), + chunkSize: -1, + chunkByteSize: 10000, + expected: nil, + shouldPanic: true, + }, + { + name: "Should_Return_Exact_chunkByteSize", + allMessages: createMessages(0, 4), + chunkSize: 2, + chunkByteSize: 8, // Each message has Value size 4, total 16 bytes + expected: [][]*Message{ + createMessages(0, 2), + createMessages(2, 4), + }, + shouldPanic: false, + }, + { + name: "Should_Handle_Varying_Message_Byte_Sizes", + allMessages: []*Message{ + {Partition: 0, Value: []byte("a")}, // 1 byte + {Partition: 1, Value: []byte("ab")}, // 2 bytes + {Partition: 2, Value: []byte("abc")}, // 3 bytes + {Partition: 3, Value: []byte("abcd")}, // 4 bytes + }, + chunkSize: 3, + chunkByteSize: 6, + expected: [][]*Message{ + { + {Partition: 0, Value: []byte("a")}, + {Partition: 1, Value: []byte("ab")}, + {Partition: 2, Value: []byte("abc")}, + }, + { + {Partition: 3, Value: []byte("abcd")}, + }, + }, + shouldPanic: false, + }, + } + + for _, tc := range tests { + tc := tc // Capture range variable + t.Run(tc.name, func(t *testing.T) { + if tc.shouldPanic { + defer func() { + if r := recover(); r == nil { + t.Errorf("Expected panic for test case '%s', but did not panic", tc.name) + } + }() + } + + chunkedMessages := chunkMessagesOptimized(tc.allMessages, tc.chunkSize, tc.chunkByteSize) + + if !tc.shouldPanic { + // Verify the number of chunks + if len(chunkedMessages) != len(tc.expected) { + t.Errorf("Test case '%s': expected %d chunks, got %d", tc.name, len(tc.expected), len(chunkedMessages)) + } + + // Verify each chunk's content + for i, expectedChunk := range tc.expected { + if i >= len(chunkedMessages) { + t.Errorf("Test case '%s': missing chunk %d", tc.name, i) + continue + } + actualChunk := chunkedMessages[i] + if !messagesEqual(actualChunk, expectedChunk) { + t.Errorf("Test case '%s': expected chunk %d to be %v, but got %v", tc.name, i, expectedChunk, actualChunk) + } + } + } + }) + } +} + func createMessages(partitionStart int, partitionEnd int) []*Message { messages := make([]*Message, 0) for i := partitionStart; i < partitionEnd; i++ { @@ -467,6 +574,21 @@ func createMessages(partitionStart int, partitionEnd int) []*Message { return messages } +func messagesEqual(a, b []*Message) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i].Partition != b[i].Partition { + return false + } + if !reflect.DeepEqual(a[i].Value, b[i].Value) { + return false + } + } + return true +} + type mockCronsumer struct { wantErr bool } diff --git a/chunkMessages_benchmark_test.go b/chunkMessages_benchmark_test.go new file mode 100644 index 0000000..b977c17 --- /dev/null +++ b/chunkMessages_benchmark_test.go @@ -0,0 +1,85 @@ +package kafka + +import ( + "math/rand" + "testing" + "time" +) + +func BenchmarkChunkMessages(b *testing.B) { + b.ReportAllocs() + rand.New(rand.NewSource(time.Now().UnixNano())) + messages := generateMessages(10000, 100) // 10,000 messages, each 100 bytes + + b.ResetTimer() + for i := 0; i < b.N; i++ { + // Create a copy of the messages slice to prevent compiler optimizations + msgsCopy := make([]*Message, len(messages)) + copy(msgsCopy, messages) + oldChunkMessages(&msgsCopy, 100, 10000) + } +} + +func BenchmarkChunkMessagesOptimized(b *testing.B) { + b.ReportAllocs() + rand.New(rand.NewSource(time.Now().UnixNano())) + messages := generateMessages(10000, 100) // 10,000 messages, each 100 bytes + + b.ResetTimer() + for i := 0; i < b.N; i++ { + // Create a copy of the messages slice to prevent compiler optimizations + msgsCopy := make([]*Message, len(messages)) + copy(msgsCopy, messages) + chunkMessagesOptimized(msgsCopy, 100, 10000) + } +} + +func oldChunkMessages(allMessages *[]*Message, chunkSize int, chunkByteSize int) [][]*Message { + var chunks [][]*Message + + allMessageList := *allMessages + var currentChunk []*Message + currentChunkSize := 0 + currentChunkBytes := 0 + + for _, message := range allMessageList { + messageByteSize := len(message.Value) + + // Check if adding this message would exceed either the chunk size or the byte size + if len(currentChunk) >= chunkSize || (chunkByteSize != 0 && currentChunkBytes+messageByteSize > chunkByteSize) { + // Avoid too low chunkByteSize + if len(currentChunk) == 0 { + panic("invalid chunk byte size, please increase it") + } + // If it does, finalize the current chunk and start a new one + chunks = append(chunks, currentChunk) + currentChunk = []*Message{} + currentChunkSize = 0 + currentChunkBytes = 0 + } + + // Add the message to the current chunk + currentChunk = append(currentChunk, message) + currentChunkSize++ + currentChunkBytes += messageByteSize + } + + // Add the last chunk if it has any messages + if len(currentChunk) > 0 { + chunks = append(chunks, currentChunk) + } + + return chunks +} + +func generateMessages(count int, valueSize int) []*Message { + messages := make([]*Message, count) + for i := 0; i < count; i++ { + b := make([]byte, valueSize) + for j := range b { + b[j] = byte(rand.Intn(26) + 97) + } + messages[i] = &Message{Value: b} + } + return messages +}