From c181d90d97bacfb128d6be58e35fa77cfc9ea898 Mon Sep 17 00:00:00 2001 From: Haris Osmanagic Date: Wed, 14 May 2025 17:21:33 +0200 Subject: [PATCH 1/5] Optimize ReadN in CDC by collecting records in batches --- source.go | 16 +-- source/logrepl/cdc.go | 159 +++++++++++++++++++----- source/logrepl/cdc_test.go | 103 +++++++++++++-- source/logrepl/combined.go | 20 +-- source/logrepl/handler.go | 93 +++++++++++--- source/logrepl/internal/subscription.go | 2 +- source/snapshot/fetch_worker.go | 19 ++- source/snapshot/fetch_worker_test.go | 20 +-- source/snapshot/iterator.go | 50 ++++---- source/snapshot/iterator_test.go | 3 +- 10 files changed, 369 insertions(+), 116 deletions(-) diff --git a/source.go b/source.go index 625e8e4..604849b 100644 --- a/source.go +++ b/source.go @@ -99,14 +99,14 @@ func (s *Source) Open(ctx context.Context, pos opencdc.Position) error { fallthrough case source.CDCModeLogrepl: i, err := logrepl.NewCombinedIterator(ctx, s.pool, logrepl.Config{ - Position: pos, - SlotName: s.config.LogreplSlotName, - PublicationName: s.config.LogreplPublicationName, - Tables: s.config.Tables, - TableKeys: s.tableKeys, - WithSnapshot: s.config.SnapshotMode == source.SnapshotModeInitial, - WithAvroSchema: s.config.WithAvroSchema, - SnapshotFetchSize: s.config.SnapshotFetchSize, + Position: pos, + SlotName: s.config.LogreplSlotName, + PublicationName: s.config.LogreplPublicationName, + Tables: s.config.Tables, + TableKeys: s.tableKeys, + WithSnapshot: s.config.SnapshotMode == source.SnapshotModeInitial, + WithAvroSchema: s.config.WithAvroSchema, + BatchSize: *s.config.BatchSize, }) if err != nil { return fmt.Errorf("failed to create logical replication iterator: %w", err) diff --git a/source/logrepl/cdc.go b/source/logrepl/cdc.go index 1c99f20..5d8a282 100644 --- a/source/logrepl/cdc.go +++ b/source/logrepl/cdc.go @@ -18,6 +18,7 @@ import ( "context" "errors" "fmt" + "time" "github.com/conduitio/conduit-commons/opencdc" "github.com/conduitio/conduit-connector-postgres/source/logrepl/internal" @@ -35,15 +36,28 @@ type CDCConfig struct { Tables []string TableKeys map[string]string WithAvroSchema bool + // BatchSize is the maximum size of a batch that will be read from the DB + // in one go and processed by the CDCHandler. + BatchSize int } // CDCIterator asynchronously listens for events from the logical replication // slot and returns them to the caller through NextN. type CDCIterator struct { - config CDCConfig - records chan opencdc.Record + config CDCConfig + sub *internal.Subscription - sub *internal.Subscription + // batchesCh is a channel shared between this iterator and a CDCHandler, + // to which the CDCHandler is sending batches of records. + // Using a shared queue here would be the fastest option. However, + // we also need to watch for a context that can get cancelled, + // and for the subscription that can end, so using a channel is + // the best option at the moment. + batchesCh chan []opencdc.Record + + // recordsForNextRead contains records from the previous batch (returned by the CDCHandler), + // that weren't returned by this iterator's NextN method. + recordsForNextRead []opencdc.Record } // NewCDCIterator initializes logical replication by creating the publication and subscription manager.
@@ -64,8 +78,22 @@ func NewCDCIterator(ctx context.Context, pool *pgxpool.Pool, c CDCConfig) (*CDCI Msgf("Publication %q already exists.", c.PublicationName) } - records := make(chan opencdc.Record) - handler := NewCDCHandler(internal.NewRelationSet(), c.TableKeys, records, c.WithAvroSchema) + // Using a buffered channel here so that the handler can send a batch + // to the channel and start building a new batch. + // This is useful when the first batch in the channel didn't reach BatchSize (which is sdk.batch.size). + // The handler can prepare the next batch, and the CDCIterator can use them + // to return the maximum number of records. + batchesCh := make(chan []opencdc.Record, 1) + handler := NewCDCHandler( + ctx, + internal.NewRelationSet(), + c.TableKeys, + batchesCh, + c.WithAvroSchema, + c.BatchSize, + // todo make configurable + time.Second, + ) sub, err := internal.CreateSubscription( ctx, @@ -81,9 +109,9 @@ func NewCDCIterator(ctx context.Context, pool *pgxpool.Pool, c CDCConfig) (*CDCI } return &CDCIterator{ - config: c, - records: records, - sub: sub, + config: c, + batchesCh: batchesCh, + sub: sub, }, nil } @@ -113,8 +141,9 @@ func (i *CDCIterator) StartSubscriber(ctx context.Context) error { return nil } -// NextN takes and returns up to n records from the queue. NextN is allowed to -// block until either at least one record is available or the context gets canceled. +// NextN returns up to n records from the internal channel with records. +// NextN is allowed to block until either at least one record is available +// or the context gets canceled. func (i *CDCIterator) NextN(ctx context.Context, n int) ([]opencdc.Record, error) { if !i.subscriberReady() { return nil, errors.New("logical replication has not been started") @@ -124,9 +153,45 @@ func (i *CDCIterator) NextN(ctx context.Context, n int) ([]opencdc.Record, error return nil, fmt.Errorf("n must be greater than 0, got %d", n) } - var recs []opencdc.Record + // First, we check if there are any records from the previous batch + // that we can start with. + recs := make([]opencdc.Record, len(i.recordsForNextRead), n) + copy(recs, i.recordsForNextRead) + i.recordsForNextRead = nil + + // NextN needs to wait until at least 1 record is available. + if len(recs) == 0 { + batch, err := i.nextRecordsBatchBlocking(ctx) + if err != nil { + return nil, fmt.Errorf("failed to fetch next batch of records (blocking): %w", err) + } + recs = batch + } + + // We add any already available batches (i.e., we're not blocking waiting for any new batches to arrive) + // to return at most n records. + for len(recs) < n { + batch, err := i.nextRecordsBatch(ctx) + if err != nil { + return nil, fmt.Errorf("failed to fetch next batch of records: %w", err) + } + if batch == nil { + break + } + recs = i.appendRecordsWithLimit(recs, batch, n) + } + + sdk.Logger(ctx).Trace(). + Int("records", len(recs)). + Int("records_for_next_read", len(i.recordsForNextRead)). + Msg("CDCIterator.NextN returning records") + return recs, nil +} - // Block until at least one record is received or context is canceled +// nextRecordsBatchBlocking waits for the next batch of records to arrive, +// or for the context to be done, or for the subscription to be done, +// whichever comes first. 
+func (i *CDCIterator) nextRecordsBatchBlocking(ctx context.Context) ([]opencdc.Record, error) { select { case <-ctx.Done(): return nil, ctx.Err() case <-i.sub.Done(): if err := i.sub.Err(); err != nil { return nil, fmt.Errorf("logical replication error: %w", err) @@ -142,33 +207,59 @@ func (i *CDCIterator) NextN(ctx context.Context, n int) ([]opencdc.Record, error // subscription stopped without an error and the context is still // open, this is a strange case, shouldn't actually happen return nil, fmt.Errorf("subscription stopped, no more data to fetch (this smells like a bug)") - case rec := <-i.records: - recs = append(recs, rec) + case batch := <-i.batchesCh: + sdk.Logger(ctx).Trace(). + Int("records", len(batch)). + Msg("CDCIterator.NextN received batch of records (blocking)") + return batch, nil } +} - for len(recs) < n { - select { - case rec := <-i.records: - recs = append(recs, rec) - case <-ctx.Done(): - return nil, ctx.Err() - case <-i.sub.Done(): - if err := i.sub.Err(); err != nil { - return recs, fmt.Errorf("logical replication error: %w", err) - } - if err := ctx.Err(); err != nil { - // Return what we have with context error - return recs, err - } - // Return what we have with subscription stopped error - return recs, fmt.Errorf("subscription stopped, no more data to fetch (this smells like a bug)") - default: - // No more records currently available - return recs, nil +func (i *CDCIterator) nextRecordsBatch(ctx context.Context) ([]opencdc.Record, error) { + select { + case <-ctx.Done(): + // Return the context error + return nil, ctx.Err() + case <-i.sub.Done(): + if err := i.sub.Err(); err != nil { + return nil, fmt.Errorf("logical replication error: %w", err) } + if err := ctx.Err(); err != nil { + // Return the context error + return nil, err + } + // Return the subscription stopped error + return nil, fmt.Errorf("subscription stopped, no more data to fetch (this smells like a bug)") + case batch := <-i.batchesCh: + sdk.Logger(ctx).Trace(). + Int("records", len(batch)). + Msg("CDCIterator.NextN received batch of records") + + return batch, nil + default: + // No more records currently available + return nil, nil } +} - return recs, nil +// appendRecordsWithLimit moves records from src to dst, until the given limit is reached, +// or all records from src have been moved. +// If some records from src are not moved (because the limit was reached), +// they are saved to recordsForNextRead. +func (i *CDCIterator) appendRecordsWithLimit(dst []opencdc.Record, src []opencdc.Record, limit int) []opencdc.Record { + if len(src) == 0 || len(dst) > limit { + return src + } + + needed := limit - len(dst) + if needed > len(src) { + needed = len(src) + } + + dst = append(dst, src[:needed]...) + i.recordsForNextRead = src[needed:] + + return dst } // Ack forwards the acknowledgment to the subscription.
diff --git a/source/logrepl/cdc_test.go b/source/logrepl/cdc_test.go index f7f9b41..8586b0c 100644 --- a/source/logrepl/cdc_test.go +++ b/source/logrepl/cdc_test.go @@ -456,6 +456,93 @@ func TestCDCIterator_Ack(t *testing.T) { }) } } + +func TestCDCIterator_NextN_InternalBatching(t *testing.T) { + ctx := test.Context(t) + pool := test.ConnectPool(ctx, t, test.RepmgrConnString) + table := test.SetupEmptyTestTable(ctx, t, pool) + + is := is.New(t) + underTest := testCDCIterator(ctx, t, pool, table, true) + <-underTest.sub.Ready() + + insertTestRows(ctx, is, pool, table, 1, 1) + // wait until the CDCHandler flushes this one record + // so that we force the CDCIterator to wait for another batch + time.Sleep(time.Second * 2) + insertTestRows(ctx, is, pool, table, 2, 5) + + // we request 2 records, expect records 1 and 2 + got, err := underTest.NextN(ctx, 2) + is.NoErr(err) + verifyOpenCDCRecords(is, got, table, 1, 2) + time.Sleep(200 * time.Millisecond) + + // we request 2 records, expect records 3 and 4 + got, err = underTest.NextN(ctx, 2) + is.NoErr(err) + verifyOpenCDCRecords(is, got, table, 3, 4) + time.Sleep(200 * time.Millisecond) + + // we request 2 records, expect record 5 + got, err = underTest.NextN(ctx, 2) + is.NoErr(err) + verifyOpenCDCRecords(is, got, table, 5, 5) +} + +func insertTestRows(ctx context.Context, is *is.I, pool *pgxpool.Pool, table string, from int, to int) { + for i := from; i <= to; i++ { + _, err := pool.Exec( + ctx, + fmt.Sprintf( + `INSERT INTO %s (id, column1, column2, column3, column4, column5) + VALUES (%d, 'test-%d', %d, false, 12.3, 14)`, table, i+10, i, i*100, + ), + ) + is.NoErr(err) + } +} + +func verifyOpenCDCRecords(is *is.I, got []opencdc.Record, tableName string, fromID, toID int) { + // Build the expected records slice + var want []opencdc.Record + + for i := fromID; i <= toID; i++ { + id := int64(i + 10) + record := opencdc.Record{ + Operation: opencdc.OperationCreate, + Key: opencdc.StructuredData{ + "id": id, + }, + Payload: opencdc.Change{ + After: opencdc.StructuredData{ + "id": id, + "key": nil, + "column1": fmt.Sprintf("test-%d", i), + "column2": int32(i) * 100, //nolint:gosec // fine, we know the value is small enough + "column3": false, + "column4": 12.3, + "column5": int64(14), + "column6": nil, + "column7": nil, + "UppercaseColumn1": nil, + }, + }, + Metadata: opencdc.Metadata{ + opencdc.MetadataCollection: tableName, + }, + } + + want = append(want, record) + } + + cmpOpts := []cmp.Option{ + cmpopts.IgnoreUnexported(opencdc.Record{}), + cmpopts.IgnoreFields(opencdc.Record{}, "Position", "Metadata"), + } + is.Equal("", cmp.Diff(want, got, cmpOpts...)) // mismatch (-want +got) +} + func TestCDCIterator_NextN(t *testing.T) { ctx := test.Context(t) pool := test.ConnectPool(ctx, t, test.RepmgrConnString) @@ -575,17 +662,12 @@ func TestCDCIterator_NextN(t *testing.T) { VALUES (30, 'test-1', 100, false, 12.3, 14)`, table)) is.NoErr(err) - go func() { - time.Sleep(100 * time.Millisecond) - is.NoErr(i.Teardown(ctx)) - }() + time.Sleep(100 * time.Millisecond) + is.NoErr(i.Teardown(ctx)) - records, err := i.NextN(ctx, 5) - if err != nil { - is.True(strings.Contains(err.Error(), "logical replication error")) - } else { - is.True(len(records) > 0) - } + _, err = i.NextN(ctx, 5) + is.True(err != nil) + is.True(strings.Contains(err.Error(), "logical replication error")) }) } @@ -597,6 +679,7 @@ func testCDCIterator(ctx context.Context, t *testing.T, pool *pgxpool.Pool, tabl PublicationName: table, // table is random, reuse for publication name SlotName: 
table, // table is random, reuse for slot name WithAvroSchema: true, + BatchSize: 2, } i, err := NewCDCIterator(ctx, pool, config) diff --git a/source/logrepl/combined.go b/source/logrepl/combined.go index 7ed853d..655e6a3 100644 --- a/source/logrepl/combined.go +++ b/source/logrepl/combined.go @@ -42,14 +42,14 @@ type CombinedIterator struct { } type Config struct { - Position opencdc.Position - SlotName string - PublicationName string - Tables []string - TableKeys map[string]string - WithSnapshot bool - WithAvroSchema bool - SnapshotFetchSize int + Position opencdc.Position + SlotName string + PublicationName string + Tables []string + TableKeys map[string]string + WithSnapshot bool + WithAvroSchema bool + BatchSize int } // Validate performs validation tasks on the config. @@ -133,6 +133,7 @@ func (c *CombinedIterator) NextN(ctx context.Context, n int) ([]opencdc.Record, sdk.Logger(ctx).Debug().Msg("Snapshot completed, switching to CDC mode") return c.NextN(ctx, n) } + return records, nil } @@ -182,6 +183,7 @@ func (c *CombinedIterator) initCDCIterator(ctx context.Context, pos position.Pos Tables: c.conf.Tables, TableKeys: c.conf.TableKeys, WithAvroSchema: c.conf.WithAvroSchema, + BatchSize: c.conf.BatchSize, }) if err != nil { return fmt.Errorf("failed to create CDC iterator: %w", err) @@ -207,7 +209,7 @@ func (c *CombinedIterator) initSnapshotIterator(ctx context.Context, pos positio Tables: c.conf.Tables, TableKeys: c.conf.TableKeys, TXSnapshotID: c.cdcIterator.TXSnapshotID(), - FetchSize: c.conf.SnapshotFetchSize, + FetchSize: c.conf.BatchSize, WithAvroSchema: c.conf.WithAvroSchema, }) if err != nil { diff --git a/source/logrepl/handler.go b/source/logrepl/handler.go index a8ad9fb..535e2a5 100644 --- a/source/logrepl/handler.go +++ b/source/logrepl/handler.go @@ -17,6 +17,8 @@ package logrepl import ( "context" "fmt" + "sync" + "time" "github.com/conduitio/conduit-commons/opencdc" cschema "github.com/conduitio/conduit-commons/schema" @@ -33,22 +35,75 @@ import ( type CDCHandler struct { tableKeys map[string]string relationSet *internal.RelationSet - out chan<- opencdc.Record - lastTXLSN pglogrepl.LSN + // batchSize is the largest number of records this handler will send at once. + batchSize int + flushInterval time.Duration + + // recordBatch holds the batch that is currently being built. + recordBatch []opencdc.Record + recordBatchLock sync.Mutex + + // out is a sending channel with batches of records. 
+ out chan<- []opencdc.Record + lastTXLSN pglogrepl.LSN withAvroSchema bool keySchemas map[string]cschema.Schema payloadSchemas map[string]cschema.Schema } -func NewCDCHandler(rs *internal.RelationSet, tableKeys map[string]string, out chan<- opencdc.Record, withAvroSchema bool) *CDCHandler { - return &CDCHandler{ +func NewCDCHandler(ctx context.Context, rs *internal.RelationSet, tableKeys map[string]string, out chan<- []opencdc.Record, withAvroSchema bool, batchSize int, flushInterval time.Duration) *CDCHandler { + h := &CDCHandler{ tableKeys: tableKeys, relationSet: rs, + recordBatch: make([]opencdc.Record, 0, batchSize), out: out, withAvroSchema: withAvroSchema, keySchemas: make(map[string]cschema.Schema), payloadSchemas: make(map[string]cschema.Schema), + batchSize: batchSize, + flushInterval: flushInterval, + } + + go h.scheduleFlushing(ctx) + + return h +} + +func (h *CDCHandler) scheduleFlushing(ctx context.Context) { + ticker := time.NewTicker(h.flushInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + err := h.flush(ctx) + if err != nil { + sdk.Logger(ctx).Err(err).Msg("failed flushing records") + } + } + } +} + +func (h *CDCHandler) flush(ctx context.Context) error { + h.recordBatchLock.Lock() + defer h.recordBatchLock.Unlock() + + if len(h.recordBatch) == 0 { + return nil + } + + select { + case <-ctx.Done(): + return ctx.Err() + case h.out <- h.recordBatch: + sdk.Logger(ctx).Trace(). + Int("records", len(h.recordBatch)). + Msg("CDCHandler sending batch of records") + h.recordBatch = make([]opencdc.Record, 0, h.batchSize) + return nil } } @@ -119,7 +174,7 @@ func (h *CDCHandler) handleInsert( ) h.attachSchemas(rec, rel.RelationName) - return h.send(ctx, rec) + return h.addToBatch(ctx, rec) } // handleUpdate formats a record with UPDATE event data from Postgres and sends @@ -159,7 +214,7 @@ func (h *CDCHandler) handleUpdate( ) h.attachSchemas(rec, rel.RelationName) - return h.send(ctx, rec) + return h.addToBatch(ctx, rec) } // handleDelete formats a record with DELETE event data from Postgres and sends @@ -191,18 +246,28 @@ func (h *CDCHandler) handleDelete( ) h.attachSchemas(rec, rel.RelationName) - return h.send(ctx, rec) + return h.addToBatch(ctx, rec) } -// send the record to the output channel or detect the cancellation of the +// addToBatch the record to the output channel or detect the cancellation of the // context and return the context error. -func (h *CDCHandler) send(ctx context.Context, rec opencdc.Record) error { - select { - case <-ctx.Done(): - return ctx.Err() - case h.out <- rec: - return nil +func (h *CDCHandler) addToBatch(ctx context.Context, rec opencdc.Record) error { + h.recordBatchLock.Lock() + + h.recordBatch = append(h.recordBatch, rec) + currentBatchSize := len(h.recordBatch) + + sdk.Logger(ctx).Trace(). + Int("current_batch_size", currentBatchSize). 
+ Msg("CDCHandler added record to batch") + + h.recordBatchLock.Unlock() + + if currentBatchSize >= h.batchSize { + return h.flush(ctx) } + + return nil } func (h *CDCHandler) buildRecordMetadata(rel *pglogrepl.RelationMessage) map[string]string { diff --git a/source/logrepl/internal/subscription.go b/source/logrepl/internal/subscription.go index 193356c..5592a02 100644 --- a/source/logrepl/internal/subscription.go +++ b/source/logrepl/internal/subscription.go @@ -193,7 +193,7 @@ func (s *Subscription) listen(ctx context.Context) error { copyDataMsg, ok := msg.(*pgproto3.CopyData) if !ok { - return fmt.Errorf("unexpected message type %T", msg) + return fmt.Errorf("unexpected message type %T, value: %v", msg, msg) } switch copyDataMsg.Data[0] { diff --git a/source/snapshot/fetch_worker.go b/source/snapshot/fetch_worker.go index cae1817..2f52787 100644 --- a/source/snapshot/fetch_worker.go +++ b/source/snapshot/fetch_worker.go @@ -92,7 +92,7 @@ type FetchData struct { type FetchWorker struct { conf FetchConfig db *pgxpool.Pool - out chan<- FetchData + out chan<- []FetchData keySchema *cschema.Schema payloadSchema *cschema.Schema @@ -102,7 +102,7 @@ type FetchWorker struct { cursorName string } -func NewFetchWorker(db *pgxpool.Pool, out chan<- FetchData, c FetchConfig) *FetchWorker { +func NewFetchWorker(db *pgxpool.Pool, out chan<- []FetchData, c FetchConfig) *FetchWorker { f := &FetchWorker{ conf: c, db: db, @@ -285,6 +285,8 @@ func (f *FetchWorker) fetch(ctx context.Context, tx pgx.Tx) (int, error) { fields := rows.FieldDescriptions() var nread int + var toBeSent []FetchData + for rows.Next() { values, err := rows.Values() if err != nil { @@ -301,12 +303,17 @@ func (f *FetchWorker) fetch(ctx context.Context, tx pgx.Tx) (int, error) { return nread, fmt.Errorf("failed to build fetch data: %w", err) } - if err := f.send(ctx, data); err != nil { + toBeSent = append(toBeSent, data) + nread++ + } + + if nread > 0 { + err := f.send(ctx, toBeSent) + if err != nil { return nread, fmt.Errorf("failed to send record: %w", err) } - - nread++ } + if rows.Err() != nil { return 0, fmt.Errorf("failed to read rows: %w", rows.Err()) } @@ -314,7 +321,7 @@ func (f *FetchWorker) fetch(ctx context.Context, tx pgx.Tx) (int, error) { return nread, nil } -func (f *FetchWorker) send(ctx context.Context, d FetchData) error { +func (f *FetchWorker) send(ctx context.Context, d []FetchData) error { start := time.Now().UTC() defer func() { sdk.Logger(ctx).Trace(). 
diff --git a/source/snapshot/fetch_worker_test.go b/source/snapshot/fetch_worker_test.go index 6c5cd00..5b7bf9c 100644 --- a/source/snapshot/fetch_worker_test.go +++ b/source/snapshot/fetch_worker_test.go @@ -36,7 +36,7 @@ import ( func Test_NewFetcher(t *testing.T) { t.Run("with initial position", func(t *testing.T) { is := is.New(t) - f := NewFetchWorker(&pgxpool.Pool{}, make(chan<- FetchData), FetchConfig{}) + f := NewFetchWorker(&pgxpool.Pool{}, make(chan<- []FetchData), FetchConfig{}) is.Equal(f.snapshotEnd, int64(0)) is.Equal(f.lastRead, int64(0)) @@ -44,7 +44,7 @@ func Test_NewFetcher(t *testing.T) { t.Run("with missing position data", func(t *testing.T) { is := is.New(t) - f := NewFetchWorker(&pgxpool.Pool{}, make(chan<- FetchData), FetchConfig{ + f := NewFetchWorker(&pgxpool.Pool{}, make(chan<- []FetchData), FetchConfig{ Position: position.Position{ Type: position.TypeSnapshot, }, @@ -57,7 +57,7 @@ func Test_NewFetcher(t *testing.T) { t.Run("resume from position", func(t *testing.T) { is := is.New(t) - f := NewFetchWorker(&pgxpool.Pool{}, make(chan<- FetchData), FetchConfig{ + f := NewFetchWorker(&pgxpool.Pool{}, make(chan<- []FetchData), FetchConfig{ Position: position.Position{ Type: position.TypeSnapshot, Snapshots: position.SnapshotPositions{ @@ -210,7 +210,7 @@ func Test_FetcherRun_EmptySnapshot(t *testing.T) { ctx = test.Context(t) pool = test.ConnectPool(context.Background(), t, test.RegularConnString) table = test.SetupEmptyTestTable(context.Background(), t, pool) - out = make(chan FetchData) + out = make(chan []FetchData) testTomb = &tomb.Tomb{} ) @@ -231,7 +231,7 @@ func Test_FetcherRun_EmptySnapshot(t *testing.T) { var gotFetchData []FetchData for data := range out { - gotFetchData = append(gotFetchData, data) + gotFetchData = append(gotFetchData, data...) } is.NoErr(testTomb.Err()) @@ -243,7 +243,7 @@ func Test_FetcherRun_Initial(t *testing.T) { pool = test.ConnectPool(context.Background(), t, test.RegularConnString) table = test.SetupTestTable(context.Background(), t, pool) is = is.New(t) - out = make(chan FetchData) + out = make(chan []FetchData) ctx = test.Context(t) tt = &tomb.Tomb{} ) @@ -265,7 +265,7 @@ func Test_FetcherRun_Initial(t *testing.T) { var gotFetchData []FetchData for data := range out { - gotFetchData = append(gotFetchData, data) + gotFetchData = append(gotFetchData, data...) } is.NoErr(tt.Err()) @@ -303,7 +303,7 @@ func Test_FetcherRun_Resume(t *testing.T) { pool = test.ConnectPool(context.Background(), t, test.RegularConnString) table = test.SetupTestTable(context.Background(), t, pool) is = is.New(t) - out = make(chan FetchData) + out = make(chan []FetchData) ctx = test.Context(t) tt = &tomb.Tomb{} ) @@ -334,7 +334,7 @@ func Test_FetcherRun_Resume(t *testing.T) { var dd []FetchData for d := range out { - dd = append(dd, d) + dd = append(dd, d...) 
} is.NoErr(tt.Err()) @@ -439,7 +439,7 @@ func Test_send(t *testing.T) { cancel() - err := f.send(ctx, FetchData{}) + err := f.send(ctx, []FetchData{{}}) is.Equal(err, context.Canceled) } diff --git a/source/snapshot/iterator.go b/source/snapshot/iterator.go index 8034fef..bafa26f 100644 --- a/source/snapshot/iterator.go +++ b/source/snapshot/iterator.go @@ -40,16 +40,17 @@ type Config struct { } type Iterator struct { - db *pgxpool.Pool - t *tomb.Tomb - workers []*FetchWorker - acks csync.WaitGroup + db *pgxpool.Pool + + workersTomb *tomb.Tomb + workers []*FetchWorker + acks csync.WaitGroup conf Config lastPosition position.Position - data chan FetchData + data chan []FetchData } func NewIterator(ctx context.Context, db *pgxpool.Pool, c Config) (*Iterator, error) { @@ -65,9 +66,9 @@ func NewIterator(ctx context.Context, db *pgxpool.Pool, c Config) (*Iterator, er t, _ := tomb.WithContext(ctx) i := &Iterator{ db: db, - t: t, + workersTomb: t, conf: c, - data: make(chan FetchData), + data: make(chan []FetchData), lastPosition: p, } @@ -93,9 +94,9 @@ func (i *Iterator) NextN(ctx context.Context, n int) ([]opencdc.Record, error) { select { case <-ctx.Done(): return nil, fmt.Errorf("iterator stopped: %w", ctx.Err()) - case d, ok := <-i.data: + case batch, ok := <-i.data: if !ok { // closed - if err := i.t.Err(); err != nil { + if err := i.workersTomb.Err(); err != nil { return nil, fmt.Errorf("fetchers exited unexpectedly: %w", err) } if err := i.acks.Wait(ctx); err != nil { @@ -104,8 +105,10 @@ func (i *Iterator) NextN(ctx context.Context, n int) ([]opencdc.Record, error) { return nil, ErrIteratorDone } - i.acks.Add(1) - records = append(records, i.buildRecord(d)) + for _, d := range batch { + i.acks.Add(1) + records = append(records, i.buildRecord(d)) + } } // Try to get remaining records non-blocking @@ -113,12 +116,14 @@ func (i *Iterator) NextN(ctx context.Context, n int) ([]opencdc.Record, error) { select { case <-ctx.Done(): return records, ctx.Err() - case d, ok := <-i.data: + case batch, ok := <-i.data: if !ok { // closed return records, nil } - i.acks.Add(1) - records = append(records, i.buildRecord(d)) + for _, d := range batch { + i.acks.Add(1) + records = append(records, i.buildRecord(d)) + } default: // No more records currently available return records, nil @@ -134,8 +139,8 @@ func (i *Iterator) Ack(_ context.Context, _ opencdc.Position) error { } func (i *Iterator) Teardown(_ context.Context) error { - if i.t != nil { - i.t.Kill(errors.New("tearing down snapshot iterator")) + if i.workersTomb != nil { + i.workersTomb.Kill(errors.New("tearing down snapshot iterator")) } return nil @@ -185,18 +190,17 @@ func (i *Iterator) initFetchers(ctx context.Context) error { } func (i *Iterator) startWorkers() { - for j := range i.workers { - f := i.workers[j] - i.t.Go(func() error { - ctx := i.t.Context(nil) //nolint:staticcheck // This is the correct usage of tomb.Context - if err := f.Run(ctx); err != nil { - return fmt.Errorf("fetcher for table %q exited: %w", f.conf.Table, err) + for _, worker := range i.workers { + i.workersTomb.Go(func() error { + ctx := i.workersTomb.Context(nil) //nolint:staticcheck // This is the correct usage of tomb.Context + if err := worker.Run(ctx); err != nil { + return fmt.Errorf("fetcher for table %q exited: %w", worker.conf.Table, err) } return nil }) } go func() { - <-i.t.Dead() + <-i.workersTomb.Dead() close(i.data) }() } diff --git a/source/snapshot/iterator_test.go b/source/snapshot/iterator_test.go index 610f2aa..8baa792 100644 --- 
a/source/snapshot/iterator_test.go +++ b/source/snapshot/iterator_test.go @@ -42,6 +42,7 @@ func Test_Iterator_NextN(t *testing.T) { TableKeys: map[string]string{ table: "id", }, + FetchSize: 2, }) is.NoErr(err) defer func() { @@ -57,7 +58,7 @@ func Test_Iterator_NextN(t *testing.T) { is.Equal(r.Metadata[opencdc.MetadataCollection], table) } - // Get remaining 2 records + // Get the remaining 2 records records, err = i.NextN(ctx, 2) is.NoErr(err) is.Equal(len(records), 2) From d0f6f29599ce06e0fc70031e2107e21cec8b865d Mon Sep 17 00:00:00 2001 From: Haris Osmanagic Date: Thu, 15 May 2025 12:06:11 +0200 Subject: [PATCH 2/5] more tests, less errors --- source/logrepl/handler.go | 44 +++++++------ source/logrepl/handler_test.go | 87 +++++++++++++++++++++++++ source/logrepl/internal/subscription.go | 3 +- 3 files changed, 114 insertions(+), 20 deletions(-) create mode 100644 source/logrepl/handler_test.go diff --git a/source/logrepl/handler.go b/source/logrepl/handler.go index 535e2a5..9147095 100644 --- a/source/logrepl/handler.go +++ b/source/logrepl/handler.go @@ -52,7 +52,15 @@ type CDCHandler struct { payloadSchemas map[string]cschema.Schema } -func NewCDCHandler(ctx context.Context, rs *internal.RelationSet, tableKeys map[string]string, out chan<- []opencdc.Record, withAvroSchema bool, batchSize int, flushInterval time.Duration) *CDCHandler { +func NewCDCHandler( + ctx context.Context, + rs *internal.RelationSet, + tableKeys map[string]string, + out chan<- []opencdc.Record, + withAvroSchema bool, + batchSize int, + flushInterval time.Duration, +) *CDCHandler { h := &CDCHandler{ tableKeys: tableKeys, relationSet: rs, @@ -76,34 +84,31 @@ func (h *CDCHandler) scheduleFlushing(ctx context.Context) { for { select { - case <-ctx.Done(): - return case <-ticker.C: - err := h.flush(ctx) - if err != nil { - sdk.Logger(ctx).Err(err).Msg("failed flushing records") - } + h.flush(ctx) } } } -func (h *CDCHandler) flush(ctx context.Context) error { +func (h *CDCHandler) flush(ctx context.Context) { h.recordBatchLock.Lock() defer h.recordBatchLock.Unlock() if len(h.recordBatch) == 0 { - return nil + return } select { case <-ctx.Done(): - return ctx.Err() + close(h.out) + sdk.Logger(ctx).Warn(). + Int("records", len(h.recordBatch)). + Msg("CDCHandler flushing records cancelled") case h.out <- h.recordBatch: sdk.Logger(ctx).Trace(). Int("records", len(h.recordBatch)). Msg("CDCHandler sending batch of records") h.recordBatch = make([]opencdc.Record, 0, h.batchSize) - return nil } } @@ -173,8 +178,9 @@ func (h *CDCHandler) handleInsert( h.buildRecordPayload(newValues), ) h.attachSchemas(rec, rel.RelationName) + h.addToBatch(ctx, rec) - return h.addToBatch(ctx, rec) + return nil } // handleUpdate formats a record with UPDATE event data from Postgres and sends @@ -213,8 +219,9 @@ func (h *CDCHandler) handleUpdate( h.buildRecordPayload(newValues), ) h.attachSchemas(rec, rel.RelationName) + h.addToBatch(ctx, rec) - return h.addToBatch(ctx, rec) + return nil } // handleDelete formats a record with DELETE event data from Postgres and sends @@ -245,13 +252,14 @@ func (h *CDCHandler) handleDelete( h.buildRecordPayload(oldValues), ) h.attachSchemas(rec, rel.RelationName) + h.addToBatch(ctx, rec) - return h.addToBatch(ctx, rec) + return nil } // addToBatch the record to the output channel or detect the cancellation of the // context and return the context error. 
-func (h *CDCHandler) addToBatch(ctx context.Context, rec opencdc.Record) error { +func (h *CDCHandler) addToBatch(ctx context.Context, rec opencdc.Record) { h.recordBatchLock.Lock() h.recordBatch = append(h.recordBatch, rec) @@ -264,10 +272,8 @@ func (h *CDCHandler) addToBatch(ctx context.Context, rec opencdc.Record) error { h.recordBatchLock.Unlock() if currentBatchSize >= h.batchSize { - return h.flush(ctx) + h.flush(ctx) } - - return nil } func (h *CDCHandler) buildRecordMetadata(rel *pglogrepl.RelationMessage) map[string]string { @@ -309,7 +315,7 @@ func (*CDCHandler) buildPosition(lsn pglogrepl.LSN) opencdc.Position { }.ToSDKPosition() } -// updateAvroSchema generates and stores avro schema based on the relation's row, +// updateAvroSchema generates and stores avro schema based on the relation's row // when usage of avro schema is requested. func (h *CDCHandler) updateAvroSchema(ctx context.Context, rel *pglogrepl.RelationMessage) error { if !h.withAvroSchema { diff --git a/source/logrepl/handler_test.go b/source/logrepl/handler_test.go new file mode 100644 index 0000000..6f65679 --- /dev/null +++ b/source/logrepl/handler_test.go @@ -0,0 +1,87 @@ +// Copyright © 2025 Meroxa, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logrepl + +import ( + "context" + "testing" + "time" + + "github.com/conduitio/conduit-commons/cchan" + "github.com/conduitio/conduit-commons/opencdc" + "github.com/matryer/is" +) + +func TestHandler_Batching_BatchSizeReached(t *testing.T) { + ctx := context.Background() + is := is.New(t) + + ch := make(chan []opencdc.Record, 1) + underTest := NewCDCHandler(ctx, nil, nil, ch, false, 5, time.Second) + want := make([]opencdc.Record, 5) + for i := 0; i < cap(want); i++ { + rec := newTestRecord(i) + underTest.addToBatch(ctx, rec) + want[i] = rec + } + + recs, gotRecs, err := cchan.ChanOut[[]opencdc.Record](ch).RecvTimeout(ctx, time.Second) + is.NoErr(err) + is.True(gotRecs) + is.Equal(recs, want) +} + +func TestHandler_Batching_WaitForTimeout(t *testing.T) { + ctx := context.Background() + is := is.New(t) + + ch := make(chan []opencdc.Record, 1) + underTest := NewCDCHandler(ctx, nil, nil, ch, false, 5, time.Second) + want := make([]opencdc.Record, 3) + for i := 0; i < cap(want); i++ { + rec := newTestRecord(i) + underTest.addToBatch(ctx, rec) + want[i] = rec + } + + start := time.Now() + recs, gotRecs, err := cchan.ChanOut[[]opencdc.Record](ch).RecvTimeout(ctx, 1200*time.Millisecond) + is.NoErr(err) + is.True(gotRecs) + is.Equal(recs, want) + is.True(time.Since(start) >= time.Second) +} + +func TestHandler_Batching_ContextCancelled(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + is := is.New(t) + + ch := make(chan []opencdc.Record, 1) + underTest := NewCDCHandler(ctx, nil, nil, ch, false, 5, time.Second) + cancel() + + underTest.addToBatch(ctx, newTestRecord(0)) + + _, gotValue := <-ch + is.True(!gotValue) +} + +func newTestRecord(id int) opencdc.Record { + return opencdc.Record{ + Key: opencdc.StructuredData{ + "id": 
id, + }, + } +} diff --git a/source/logrepl/internal/subscription.go b/source/logrepl/internal/subscription.go index 5592a02..7c09928 100644 --- a/source/logrepl/internal/subscription.go +++ b/source/logrepl/internal/subscription.go @@ -141,7 +141,8 @@ func CreateSubscription( }, nil } -// Run logical replication listener and block until error or ctx is canceled. +// Run the logical replication listener and block until it returns an error, +// or the context is canceled. func (s *Subscription) Run(ctx context.Context) error { defer s.doneReplication() From 5059add4436a03db221f509661df447e3ad538db Mon Sep 17 00:00:00 2001 From: Haris Osmanagic Date: Thu, 15 May 2025 13:21:57 +0200 Subject: [PATCH 3/5] simplify --- source/logrepl/handler.go | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/source/logrepl/handler.go b/source/logrepl/handler.go index 9147095..be1173b 100644 --- a/source/logrepl/handler.go +++ b/source/logrepl/handler.go @@ -82,11 +82,8 @@ func (h *CDCHandler) scheduleFlushing(ctx context.Context) { ticker := time.NewTicker(h.flushInterval) defer ticker.Stop() - for { - select { - case <-ticker.C: - h.flush(ctx) - } + for range ticker.C { + h.flush(ctx) } } From af2a5a2784a3e66e046c27445cef1e53df7f1076 Mon Sep 17 00:00:00 2001 From: Haris Osmanagic Date: Thu, 15 May 2025 14:05:37 +0200 Subject: [PATCH 4/5] fix ctx handling --- source/logrepl/handler.go | 17 ++++++++++------- source/logrepl/handler_test.go | 6 +++--- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/source/logrepl/handler.go b/source/logrepl/handler.go index be1173b..77ebfa2 100644 --- a/source/logrepl/handler.go +++ b/source/logrepl/handler.go @@ -16,6 +16,7 @@ package logrepl import ( "context" + "errors" "fmt" "sync" "time" @@ -95,18 +96,20 @@ func (h *CDCHandler) flush(ctx context.Context) { return } - select { - case <-ctx.Done(): + if errors.Is(ctx.Err(), context.Canceled) { close(h.out) sdk.Logger(ctx).Warn(). + Err(ctx.Err()). Int("records", len(h.recordBatch)). Msg("CDCHandler flushing records cancelled") - case h.out <- h.recordBatch: - sdk.Logger(ctx).Trace(). - Int("records", len(h.recordBatch)). - Msg("CDCHandler sending batch of records") - h.recordBatch = make([]opencdc.Record, 0, h.batchSize) + return } + + h.out <- h.recordBatch + sdk.Logger(ctx).Debug(). + Int("records", len(h.recordBatch)). + Msg("CDCHandler sending batch of records") + h.recordBatch = make([]opencdc.Record, 0, h.batchSize) } // Handle is the handler function that receives all logical replication messages.
diff --git a/source/logrepl/handler_test.go b/source/logrepl/handler_test.go index 6f65679..fc61537 100644 --- a/source/logrepl/handler_test.go +++ b/source/logrepl/handler_test.go @@ -71,11 +71,11 @@ func TestHandler_Batching_ContextCancelled(t *testing.T) { ch := make(chan []opencdc.Record, 1) underTest := NewCDCHandler(ctx, nil, nil, ch, false, 5, time.Second) cancel() - + <-ctx.Done() underTest.addToBatch(ctx, newTestRecord(0)) - _, gotValue := <-ch - is.True(!gotValue) + _, recordReceived := <-ch + is.True(!recordReceived) } func newTestRecord(id int) opencdc.Record { From a65911667f3091547c79da5abb84d431a57b42c7 Mon Sep 17 00:00:00 2001 From: Haris Osmanagic Date: Thu, 15 May 2025 14:10:37 +0200 Subject: [PATCH 5/5] add comment --- source/logrepl/handler_test.go | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/source/logrepl/handler_test.go b/source/logrepl/handler_test.go index fc61537..5fb0f31 100644 --- a/source/logrepl/handler_test.go +++ b/source/logrepl/handler_test.go @@ -43,12 +43,16 @@ func TestHandler_Batching_BatchSizeReached(t *testing.T) { is.Equal(recs, want) } -func TestHandler_Batching_WaitForTimeout(t *testing.T) { +// TestHandler_Batching_FlushInterval tests if the handler flushes +// a batch once the flush interval passes, even if the batch size is not reached. +func TestHandler_Batching_FlushInterval(t *testing.T) { ctx := context.Background() is := is.New(t) ch := make(chan []opencdc.Record, 1) - underTest := NewCDCHandler(ctx, nil, nil, ch, false, 5, time.Second) + flushInterval := time.Second + underTest := NewCDCHandler(ctx, nil, nil, ch, false, 5, flushInterval) + want := make([]opencdc.Record, 3) for i := 0; i < cap(want); i++ { rec := newTestRecord(i) @@ -58,10 +62,11 @@ func TestHandler_Batching_WaitForTimeout(t *testing.T) { start := time.Now() recs, gotRecs, err := cchan.ChanOut[[]opencdc.Record](ch).RecvTimeout(ctx, 1200*time.Millisecond) + is.NoErr(err) is.True(gotRecs) is.Equal(recs, want) - is.True(time.Since(start) >= time.Second) + is.True(time.Since(start) >= flushInterval) } func TestHandler_Batching_ContextCancelled(t *testing.T) {