diff --git a/pqarrow/arrowutils/merge.go b/pqarrow/arrowutils/merge.go
index e9d9d1117..8df2a9cb3 100644
--- a/pqarrow/arrowutils/merge.go
+++ b/pqarrow/arrowutils/merge.go
@@ -4,6 +4,7 @@ import (
 	"bytes"
 	"container/heap"
 	"fmt"
+	"math"
 
 	"github.com/apache/arrow/go/v14/arrow"
 	"github.com/apache/arrow/go/v14/arrow/array"
@@ -14,11 +15,15 @@ import (
 
 // MergeRecords merges the given records. The records must all have the same
 // schema. orderByCols is a slice of indexes into the columns that the records
-// and resulting records are ordered by. Note that the given records should
-// already be ordered by the given columns.
+// and resulting records are ordered by. While merging, the limit is checked before appending more rows.
+// If limit is 0, no limit is applied.
+// Note that the given records should already be ordered by the given columns.
 // WARNING: Only ascending ordering is currently supported.
 func MergeRecords(
-	mem memory.Allocator, records []arrow.Record, orderByCols []int,
+	mem memory.Allocator,
+	records []arrow.Record,
+	orderByCols []int,
+	limit uint64,
 ) (arrow.Record, error) {
 	h := cursorHeap{
 		cursors: make([]cursor, len(records)),
@@ -32,8 +37,13 @@ func MergeRecords(
 	recordBuilder := builder.NewRecordBuilder(mem, schema)
 	defer recordBuilder.Release()
 
+	if limit == 0 {
+		limit = math.MaxInt64
+	}
+	count := uint64(0)
+
 	heap.Init(&h)
-	for h.Len() > 0 {
+	for h.Len() > 0 && count < limit {
 		// Minimum cursor is always at index 0.
 		r := h.cursors[0].r
 		i := h.cursors[0].curIdx
@@ -45,10 +55,12 @@ func MergeRecords(
 		if int64(i+1) >= r.NumRows() {
 			// Pop the cursor since it has no more data.
 			_ = heap.Pop(&h)
+			count++
 			continue
 		}
 		h.cursors[0].curIdx++
 		heap.Fix(&h, 0)
+		count++
 	}
 
 	return recordBuilder.NewRecord(), nil
diff --git a/pqarrow/arrowutils/merge_test.go b/pqarrow/arrowutils/merge_test.go
index 082f85a95..3fd041e9c 100644
--- a/pqarrow/arrowutils/merge_test.go
+++ b/pqarrow/arrowutils/merge_test.go
@@ -35,7 +35,7 @@ func TestMerge(t *testing.T) {
 	record3 := array.NewRecord(schema, []arrow.Array{a}, int64(a.Len()))
 
 	res, err := arrowutils.MergeRecords(
-		memory.DefaultAllocator, []arrow.Record{record1, record2, record3}, []int{0},
+		memory.DefaultAllocator, []arrow.Record{record1, record2, record3}, []int{0}, 0,
 	)
 	require.NoError(t, err)
 	require.Equal(t, int64(1), res.NumCols())
@@ -47,4 +47,11 @@ func TestMerge(t *testing.T) {
 	for i := 1; i < col.Len(); i++ {
 		require.Equal(t, expected[i-1], col.Value(i))
 	}
+
+	// check that we can merge with a limit
+	res, err = arrowutils.MergeRecords(
+		memory.DefaultAllocator, []arrow.Record{record1, record2, record3}, []int{0}, 3,
+	)
+	require.NoError(t, err)
+	require.Equal(t, int64(3), res.NumRows())
 }
diff --git a/query/physicalplan/ordered_aggregate.go b/query/physicalplan/ordered_aggregate.go
index ad973ec90..5f6f0ebfe 100644
--- a/query/physicalplan/ordered_aggregate.go
+++ b/query/physicalplan/ordered_aggregate.go
@@ -486,7 +486,7 @@ func (a *OrderedAggregate) Finish(ctx context.Context) error {
 		for i := range orderByCols {
 			orderByCols[i] = i
 		}
-		mergedRecord, err := arrowutils.MergeRecords(a.pool, records, orderByCols)
+		mergedRecord, err := arrowutils.MergeRecords(a.pool, records, orderByCols, 0)
 		if err != nil {
 			return err
 		}
diff --git a/query/physicalplan/ordered_synchronizer.go b/query/physicalplan/ordered_synchronizer.go
index ccf7abed4..444b6301c 100644
--- a/query/physicalplan/ordered_synchronizer.go
+++ b/query/physicalplan/ordered_synchronizer.go
@@ -117,7 +117,7 @@ func (o *OrderedSynchronizer) mergeRecordsLocked() (arrow.Record, error) {
 	if err := o.ensureSameSchema(o.sync.data); err != nil {
 		return nil, err
 	}
-	mergedRecord, err := arrowutils.MergeRecords(o.pool, o.sync.data, o.orderByCols)
+	mergedRecord, err := arrowutils.MergeRecords(o.pool, o.sync.data, o.orderByCols, 0)
 	if err != nil {
 		return nil, err
 	}
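
For reviewers, a minimal usage sketch of the new MergeRecords signature. The helper name mergeTopN and the frostdb import path are assumptions for illustration, not part of this patch; only the MergeRecords call matches the diff above.

package example

import (
	"github.com/apache/arrow/go/v14/arrow"
	"github.com/apache/arrow/go/v14/arrow/memory"

	"github.com/polarsignals/frostdb/pqarrow/arrowutils"
)

// mergeTopN is an illustrative helper (not part of this patch). It merges
// records that are already sorted ascending on column 0 and keeps at most
// n rows of the merged output; passing n == 0 keeps the previous unbounded
// behaviour.
func mergeTopN(recs []arrow.Record, n uint64) (arrow.Record, error) {
	return arrowutils.MergeRecords(memory.DefaultAllocator, recs, []int{0}, n)
}

Because the limit is checked at the top of the merge loop and incremented once per appended row, the merged record holds at most limit rows, or fewer if the inputs run out first.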