Skip to content

Commit

Permalink
Update SortedMerge to use a PriorityQueue (#657)
Browse files Browse the repository at this point in the history
  • Loading branch information
danielearwicker authored May 4, 2024
1 parent ef180fd commit e89f29d
Show file tree
Hide file tree
Showing 6 changed files with 213 additions and 57 deletions.
107 changes: 78 additions & 29 deletions Source/SuperLinq.Async/SortedMerge.cs
Original file line number Diff line number Diff line change
Expand Up @@ -276,13 +276,10 @@ params IAsyncEnumerable<TSource>[] otherSequences
// Private implementation method that performs a merge of multiple, ordered sequences using
// a precedence function which encodes order-sensitive comparison logic based on the caller's arguments.
//
// The algorithm employed in this implementation is not necessarily the most optimal way to merge
// two sequences. A swap-compare version would probably be somewhat more efficient - but at the
// expense of considerably more complexity. One possible optimization would be to detect that only
// a single sequence remains (all other being consumed) and break out of the main while-loop and
// simply yield the items that are part of the final sequence.
// Where available, PriorityQueue is used to maintain the remaining enumerators ordered by the
// key that the key selector produces from each enumerator's Current element.
//
// The algorithm used here will perform N*(K1+K2+...Kn-1) comparisons, where <c>N => otherSequences.Count()+1.</c>
// Otherwise, a sorted array is used, with BinarySearch finding the re-insert location.

static async IAsyncEnumerable<TSource> Impl(
IEnumerable<IAsyncEnumerable<TSource>> sequences,
Expand All @@ -291,39 +288,91 @@ static async IAsyncEnumerable<TSource> Impl(
[EnumeratorCancellation] CancellationToken cancellationToken = default
)
{
var list = await EnumeratorList<TSource>.Create(sequences, cancellationToken).ConfigureAwait(false);
await using var ignored_ = list.ConfigureAwait(false);
var enumerators = new List<IAsyncEnumerator<TSource>>();

// prime all of the iterators by advancing them to their first element (if any)
for (var i = 0; await list.MoveNext(i).ConfigureAwait(false); i++)
{ }

// while all iterators have not yet been consumed...
while (list.Any())
try
{
var nextIndex = 0;
var nextValue = list.Current(0);
var nextKey = keySelector(nextValue);
// Ensure we dispose first N enumerators if N+1 throws
foreach (var sequence in sequences)
{
var e = sequence.GetAsyncEnumerator(cancellationToken);
if (await e.MoveNextAsync())
enumerators.Add(e);
else
await e.DisposeAsync();
}
#if NET6_0_OR_GREATER
var queue = new PriorityQueue<IAsyncEnumerator<TSource>, TKey>(
enumerators.Select(x => (x, keySelector(x.Current))),
comparer);

#pragma warning disable CA2000 // e will be disposed via enumerators list
while (queue.TryDequeue(out var e, out var _))
#pragma warning restore CA2000 // Dispose objects before losing scope
{
yield return e.Current;

// Fast drain of final enumerator
if (queue.Count == 0)
{
while (await e.MoveNextAsync()) yield return e.Current;
break;
}

if (await e.MoveNextAsync()) queue.Enqueue(e, keySelector(e.Current));
}

#else
enumerators.Sort((x, y) => comparer.Compare(keySelector(x.Current), keySelector(y.Current)));

// find the next least element to return
for (var i = 1; i < list.Count; i++)
var arr = enumerators.ToArray();
var count = arr.Length;
var sourceComparer = new SourceComparer<TSource, TKey>(comparer, keySelector);

while (count > 1)
{
var anotherElement = list.Current(i);
var anotherKey = keySelector(anotherElement);
// determine which element follows based on ordering function
if (comparer.Compare(nextKey, anotherKey) > 0)
var e = arr[0];
yield return e.Current;

if (!await e.MoveNextAsync())
{
nextIndex = i;
nextValue = anotherElement;
nextKey = anotherKey;
count--;
Array.Copy(arr, 1, arr, 0, count);
continue;
}

var index = Array.BinarySearch(arr, 1, count - 1, e, sourceComparer);
if (index < 0) index = ~index;

index--;
if (index > 0) Array.Copy(arr, 1, arr, 0, index);
arr[index] = e;
}

yield return nextValue; // next value in precedence order
if (count == 1)
{
var e = arr[0];
yield return e.Current;

// advance iterator that yielded element, excluding it when consumed
_ = await list.MoveNextOnce(nextIndex).ConfigureAwait(false);
while (await e.MoveNextAsync()) yield return e.Current;
}
#endif
}
finally
{
foreach (var e in enumerators) await e.DisposeAsync();
}
}
}

#if !NET6_0_OR_GREATER
/// <summary>
/// Compares two asynchronous enumerators by the key of the element each one is currently
/// positioned on.
/// </summary>
/// <typeparam name="TItem">Type of the elements produced by the enumerators.</typeparam>
/// <typeparam name="TKey">Type of the key extracted from each element.</typeparam>
internal sealed class SourceComparer<TItem, TKey> : IComparer<IAsyncEnumerator<TItem>>
{
    private readonly IComparer<TKey> _keyComparer;
    private readonly Func<TItem, TKey> _keySelector;

    /// <summary>
    /// Creates a comparer from a key comparer and the function that extracts a key from an element.
    /// </summary>
    /// <param name="keyComparer">Comparer applied to the extracted keys.</param>
    /// <param name="keySelector">Function that extracts the key from an element.</param>
    public SourceComparer(IComparer<TKey> keyComparer, Func<TItem, TKey> keySelector)
    {
        _keyComparer = keyComparer;
        _keySelector = keySelector;
    }

    /// <summary>
    /// Compares the keys of the <c>Current</c> elements of <paramref name="x"/> and <paramref name="y"/>.
    /// </summary>
    /// <remarks>
    /// Both arguments are dereferenced unconditionally, so neither may be <see langword="null"/>.
    /// </remarks>
    /// <returns>The result of the key comparer for the two extracted keys.</returns>
    public int Compare(IAsyncEnumerator<TItem>? x, IAsyncEnumerator<TItem>? y)
    {
        var leftKey = _keySelector(x!.Current);
        var rightKey = _keySelector(y!.Current);
        return _keyComparer.Compare(leftKey, rightKey);
    }
}
#endif
}
107 changes: 79 additions & 28 deletions Source/SuperLinq/SortedMerge.cs
Original file line number Diff line number Diff line change
Expand Up @@ -355,48 +355,99 @@ public static IEnumerable<TSource> SortedMergeBy<TSource, TKey>(
// Private implementation method that performs a merge of multiple, ordered sequences using
// a key selector and a key comparer which together encode the caller's ordering.
//
// Every source is primed up front (advanced to its first element). Sources that turn out to be
// empty are disposed immediately and take no part in the merge.
//
// Where available, PriorityQueue is used to maintain the remaining enumerators ordered by the
// key selected from each enumerator's Current element.
//
// Otherwise, a sorted array is used, with BinarySearch finding the re-insert location.
//
// In both variants the last remaining enumerator is drained directly, without any further
// key extraction or comparison.
//
// NOTE: neither PriorityQueue nor List.Sort is documented as stable, so the relative order of
// elements with equal keys that come from different sequences should not be relied upon.

static IEnumerable<TSource> Impl(IEnumerable<IEnumerable<TSource>> sequences, Func<TSource, TKey> keySelector, IComparer<TKey> comparer)
{
    // Every enumerator in this list is positioned on an element and is disposed by the
    // finally block below, however the iteration ends.
    var enumerators = new List<IEnumerator<TSource>>();

    try
    {
        // Prime all of the enumerators by advancing them to their first element (if any).
        foreach (var sequence in sequences)
        {
            var e = sequence.GetEnumerator();
            var hasElement = false;

            try
            {
                hasElement = e.MoveNext();
            }
            finally
            {
                // Covers both an empty source and a MoveNext that throws: in either case the
                // enumerator never reaches the list, so it has to be disposed right here.
                if (!hasElement)
                    e.Dispose();
            }

            if (hasElement)
                enumerators.Add(e);
        }

#if NET6_0_OR_GREATER
        // Min-queue of the live enumerators, prioritised by the key of their current element.
        var queue = new PriorityQueue<IEnumerator<TSource>, TKey>(
            enumerators.Select(x => (x, keySelector(x.Current))),
            comparer);

#pragma warning disable CA2000 // e will be disposed via enumerators list
        while (queue.TryDequeue(out var e, out var _))
#pragma warning restore CA2000 // Dispose objects before losing scope
        {
            yield return e.Current;

            // Fast drain of final enumerator
            if (queue.Count == 0)
            {
                while (e.MoveNext()) yield return e.Current;
                break;
            }

            // Re-queue under the key of its next element; an exhausted enumerator simply
            // drops out of the queue (it is still disposed through the list).
            if (e.MoveNext()) queue.Enqueue(e, keySelector(e.Current));
        }

#else
        enumerators.Sort((x, y) => comparer.Compare(keySelector(x.Current), keySelector(y.Current)));

        // arr[0..count) holds the live enumerators in ascending key order; arr[0] is always next.
        var arr = enumerators.ToArray();
        var count = arr.Length;
        var sourceComparer = new SourceComparer<TSource, TKey>(comparer, keySelector);

        while (count > 1)
        {
            var e = arr[0];
            yield return e.Current;

            if (!e.MoveNext())
            {
                // Exhausted: close the gap at the front of the array.
                count--;
                Array.Copy(arr, 1, arr, 0, count);
                continue;
            }

            // Locate where the advanced enumerator belongs among the others (arr[1..count)).
            var index = Array.BinarySearch(arr, 1, count - 1, e, sourceComparer);
            if (index < 0) index = ~index;

            // Everything ahead of that position shifts left by one, so the target slot does too.
            index--;
            if (index > 0) Array.Copy(arr, 1, arr, 0, index);
            arr[index] = e;
        }

        if (count == 1)
        {
            // Only one enumerator left: yield the rest of it without any comparisons.
            var e = arr[0];
            yield return e.Current;

            while (e.MoveNext()) yield return e.Current;
        }
#endif
    }
    finally
    {
        foreach (var e in enumerators) e.Dispose();
    }
}
}

#if !NET6_0_OR_GREATER
/// <summary>
/// Compares two enumerators by the key of the element each one is currently positioned on.
/// </summary>
/// <typeparam name="TItem">Type of the elements produced by the enumerators.</typeparam>
/// <typeparam name="TKey">Type of the key extracted from each element.</typeparam>
internal sealed class SourceComparer<TItem, TKey> : IComparer<IEnumerator<TItem>>
{
    private readonly IComparer<TKey> _keyComparer;
    private readonly Func<TItem, TKey> _keySelector;

    /// <summary>
    /// Creates a comparer from a key comparer and the function that extracts a key from an element.
    /// </summary>
    /// <param name="keyComparer">Comparer applied to the extracted keys.</param>
    /// <param name="keySelector">Function that extracts the key from an element.</param>
    public SourceComparer(IComparer<TKey> keyComparer, Func<TItem, TKey> keySelector)
    {
        _keyComparer = keyComparer;
        _keySelector = keySelector;
    }

    /// <summary>
    /// Compares the keys of the <c>Current</c> elements of <paramref name="x"/> and <paramref name="y"/>.
    /// </summary>
    /// <remarks>
    /// Both arguments are dereferenced unconditionally, so neither may be <see langword="null"/>.
    /// </remarks>
    /// <returns>The result of the key comparer for the two extracted keys.</returns>
    public int Compare(IEnumerator<TItem>? x, IEnumerator<TItem>? y)
    {
        var leftKey = _keySelector(x!.Current);
        var rightKey = _keySelector(y!.Current);
        return _keyComparer.Compare(leftKey, rightKey);
    }
}
#endif
}
14 changes: 14 additions & 0 deletions Tests/SuperLinq.Async.Test/SortedMergeByTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -131,4 +131,18 @@ public async Task TestSortedMergeByCustomComparer()

await result.AssertSequenceEqual(expectedResult);
}

/// <summary>
/// Verify that SortedMergeBy interleaves three ascending sequences that share some values,
/// keeping every duplicate in the merged output.
/// </summary>
[Fact]
public async Task TestSortedMergeOverlappingSequences()
{
    await using var first = TestingSequence.Of(1, 3, 5, 7, 9, 11);
    await using var second = TestingSequence.Of(1, 4, 5, 10, 12);
    await using var third = TestingSequence.Of(2, 4, 6, 8, 10, 12);

    int[] expected = [1, 1, 2, 3, 4, 4, 5, 5, 6, 7, 8, 9, 10, 10, 11, 12, 12];

    var merged = first.SortedMergeBy(SuperEnumerable.Identity, second, third);
    await merged.AssertSequenceEqual(expected);
}
}
14 changes: 14 additions & 0 deletions Tests/SuperLinq.Async.Test/SortedMergeTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -131,4 +131,18 @@ public async Task TestSortedMergeCustomComparer()

await result.AssertSequenceEqual(expectedResult);
}

/// <summary>
/// Verify that SortedMerge interleaves three ascending sequences that share some values,
/// keeping every duplicate in the merged output.
/// </summary>
[Fact]
public async Task TestSortedMergeOverlappingSequences()
{
    await using var first = TestingSequence.Of(1, 3, 5, 7, 9, 11);
    await using var second = TestingSequence.Of(1, 4, 5, 10, 12);
    await using var third = TestingSequence.Of(2, 4, 6, 8, 10, 12);

    int[] expected = [1, 1, 2, 3, 4, 4, 5, 5, 6, 7, 8, 9, 10, 10, 11, 12, 12];

    var merged = first.SortedMerge(second, third);
    await merged.AssertSequenceEqual(expected);
}
}
14 changes: 14 additions & 0 deletions Tests/SuperLinq.Test/SortedMergeByTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -131,4 +131,18 @@ public void TestSortedMergeByCustomComparer()

result.AssertSequenceEqual(expectedResult);
}

/// <summary>
/// Verify that SortedMergeBy interleaves three ascending sequences that share some values,
/// keeping every duplicate in the merged output.
/// </summary>
[Fact]
public void TestSortedMergeByOverlappingSequences()
{
    using var first = TestingSequence.Of(1, 3, 5, 7, 9, 11);
    using var second = TestingSequence.Of(1, 4, 5, 10, 12);
    using var third = TestingSequence.Of(2, 4, 6, 8, 10, 12);

    int[] expected = [1, 1, 2, 3, 4, 4, 5, 5, 6, 7, 8, 9, 10, 10, 11, 12, 12];

    var merged = first.SortedMergeBy(SuperEnumerable.Identity, second, third);
    merged.AssertSequenceEqual(expected);
}
}
14 changes: 14 additions & 0 deletions Tests/SuperLinq.Test/SortedMergeTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -131,4 +131,18 @@ public void TestSortedMergeCustomComparer()

result.AssertSequenceEqual(expectedResult);
}

/// <summary>
/// Verify that SortedMerge interleaves three ascending sequences that share some values,
/// keeping every duplicate in the merged output.
/// </summary>
[Fact]
public void TestSortedMergeOverlappingSequences()
{
    using var first = TestingSequence.Of(1, 3, 5, 7, 9, 11);
    using var second = TestingSequence.Of(1, 4, 5, 10, 12);
    using var third = TestingSequence.Of(2, 4, 6, 8, 10, 12);

    int[] expected = [1, 1, 2, 3, 4, 4, 5, 5, 6, 7, 8, 9, 10, 10, 11, 12, 12];

    var merged = first.SortedMerge(second, third);
    merged.AssertSequenceEqual(expected);
}
}

0 comments on commit e89f29d

Please sign in to comment.