Skip to content

Commit

Permalink
Expose ChannelReader.ReadAllAsync from non-core builds as well (dotne…
Browse files Browse the repository at this point in the history
…t#94417)

* Expose ChannelReader.ReadAllAsync from non-core builds as well

I'm not sure why we didn't fix this before when we shipped Microsoft.Bcl.AsyncInterfaces, but with IAsyncEnumerable available downlevel, there's no need to hide this method away; it can be in all builds.  Doing so makes it easier for others to create their own channel implementations, as they don't _need_ to multitarget in order to override everything they might want to.

I've not changed any C# code, just moved it between files.

* Address PR feedback
  • Loading branch information
stephentoub authored Nov 14, 2023
1 parent 81393b7 commit ab2d63b
Show file tree
Hide file tree
Showing 9 changed files with 294 additions and 325 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ protected ChannelReader() { }
public virtual System.Threading.Tasks.Task Completion { get { throw null; } }
public virtual int Count { get { throw null; } }
public virtual System.Threading.Tasks.ValueTask<T> ReadAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public virtual System.Collections.Generic.IAsyncEnumerable<T> ReadAllAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public virtual bool TryPeek([System.Diagnostics.CodeAnalysis.MaybeNullWhenAttribute(false)] out T item) { throw null; }
public abstract bool TryRead([System.Diagnostics.CodeAnalysis.MaybeNullWhenAttribute(false)] out T item);
public abstract System.Threading.Tasks.ValueTask<bool> WaitToReadAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,6 @@

<ItemGroup Condition="'$(TargetFrameworkIdentifier)' != '.NETCoreApp'">
<PackageReference Include="System.Threading.Tasks.Extensions" Version="$(SystemThreadingTasksExtensionsVersion)" />
<ProjectReference Include="$(LibrariesProjectRoot)Microsoft.Bcl.AsyncInterfaces\ref\Microsoft.Bcl.AsyncInterfaces.csproj" />
</ItemGroup>
</Project>
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,4 @@ public partial class ChannelClosedException : System.InvalidOperationException
#endif
protected ChannelClosedException(System.Runtime.Serialization.SerializationInfo info, System.Runtime.Serialization.StreamingContext context) { }
}
public abstract partial class ChannelReader<T>
{
public virtual System.Collections.Generic.IAsyncEnumerable<T> ReadAllAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@ System.Threading.Channel&lt;T&gt;</PackageDescription>
Condition="$([MSBuild]::IsTargetFrameworkCompatible('$(TargetFramework)', 'netstandard2.1'))" />
<Compile Include="System\Threading\Channels\ChannelOptions.cs" />
<Compile Include="System\Threading\Channels\ChannelReader.cs" />
<Compile Include="System\Threading\Channels\ChannelReader.netcoreapp.cs"
Condition="$([MSBuild]::IsTargetFrameworkCompatible('$(TargetFramework)', 'netstandard2.1'))" />
<Compile Include="System\Threading\Channels\ChannelUtilities.cs" />
<Compile Include="System\Threading\Channels\ChannelWriter.cs" />
<Compile Include="System\Threading\Channels\Channel_1.cs" />
Expand All @@ -40,7 +38,7 @@ System.Threading.Channel&lt;T&gt;</PackageDescription>
<Compile Include="$(CommonPath)Internal\Padding.cs"
Link="Common\Internal\Padding.cs" />
<Compile Include="$(CommonPath)System\Collections\Concurrent\IProducerConsumerQueue.cs"
Link="Common\System\Collections\Concurrent\IProducerConsumerQueue.cs" />
Link="Common\System\Collections\Concurrent\IProducerConsumerQueue.cs" />
<Compile Include="$(CommonPath)System\Collections\Concurrent\MultiProducerMultiConsumerQueue.cs"
Link="Common\System\Collections\Concurrent\MultiProducerMultiConsumerQueue.cs" />
<Compile Include="$(CommonPath)System\Collections\Concurrent\SingleProducerSingleConsumerQueue.cs"
Expand All @@ -61,6 +59,7 @@ System.Threading.Channel&lt;T&gt;</PackageDescription>

<ItemGroup Condition="!$([MSBuild]::IsTargetFrameworkCompatible('$(TargetFramework)', 'netstandard2.1'))">
<PackageReference Include="System.Threading.Tasks.Extensions" Version="$(SystemThreadingTasksExtensionsVersion)" />
<ProjectReference Include="$(LibrariesProjectRoot)Microsoft.Bcl.AsyncInterfaces\src\Microsoft.Bcl.AsyncInterfaces.csproj" />
</ItemGroup>

</Project>
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Runtime.CompilerServices;
using System.Threading.Tasks;

namespace System.Threading.Channels
Expand Down Expand Up @@ -90,5 +92,23 @@ async ValueTask<T> ReadAsyncCore(CancellationToken ct)
}
}
}

/// <summary>Creates an <see cref="IAsyncEnumerable{T}"/> that enables reading all of the data from the channel.</summary>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to use to cancel the enumeration.</param>
/// <remarks>
/// Each <see cref="IAsyncEnumerator{T}.MoveNextAsync"/> call that returns <c>true</c> will read the next item out of the channel.
/// <see cref="IAsyncEnumerator{T}.MoveNextAsync"/> will return false once no more data is or will ever be available to read.
/// </remarks>
/// <returns>The created async enumerable.</returns>
public virtual async IAsyncEnumerable<T> ReadAllAsync([EnumeratorCancellation] CancellationToken cancellationToken = default)
{
while (await WaitToReadAsync(cancellationToken).ConfigureAwait(false))
{
while (TryRead(out T? item))
{
yield return item;
}
}
}
}
}

This file was deleted.

266 changes: 266 additions & 0 deletions src/libraries/System.Threading.Channels/tests/ChannelTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -784,6 +784,272 @@ public async Task ReadAsync_ConsecutiveReadsSucceed()
}
}

[Fact]
public void ReadAllAsync_NotIdempotent()
{
Channel<int> c = CreateChannel();
IAsyncEnumerable<int> e = c.Reader.ReadAllAsync();
Assert.NotNull(e);
Assert.NotSame(e, c.Reader.ReadAllAsync());
}

[Theory]
[InlineData(false)]
[InlineData(true)]
public async Task ReadAllAsync_UseMoveNextAsyncAfterCompleted_ReturnsFalse(bool completeWhilePending)
{
Channel<int> c = CreateChannel();
IAsyncEnumerator<int> e = c.Reader.ReadAllAsync().GetAsyncEnumerator();

ValueTask<bool> vt;
if (completeWhilePending)
{
c.Writer.Complete();
vt = e.MoveNextAsync();
Assert.True(vt.IsCompletedSuccessfully);
Assert.False(vt.Result);
}
else
{
vt = e.MoveNextAsync();
Assert.False(vt.IsCompleted);
c.Writer.Complete();
Assert.False(await vt);
}

vt = e.MoveNextAsync();
Assert.True(vt.IsCompletedSuccessfully);
Assert.False(vt.Result);
}

[Fact]
public void ReadAllAsync_AvailableDataCompletesSynchronously()
{
Channel<int> c = CreateChannel();

IAsyncEnumerator<int> e = c.Reader.ReadAllAsync().GetAsyncEnumerator();
try
{
for (int i = 100; i < 110; i++)
{
Assert.True(c.Writer.TryWrite(i));
ValueTask<bool> vt = e.MoveNextAsync();
Assert.True(vt.IsCompletedSuccessfully);
Assert.True(vt.Result);
Assert.Equal(i, e.Current);
}
}
finally
{
ValueTask vt = e.DisposeAsync();
Assert.True(vt.IsCompletedSuccessfully);
vt.GetAwaiter().GetResult();
}
}

[Fact]
public async Task ReadAllAsync_UnavailableDataCompletesAsynchronously()
{
Channel<int> c = CreateChannel();

IAsyncEnumerator<int> e = c.Reader.ReadAllAsync().GetAsyncEnumerator();
try
{
for (int i = 100; i < 110; i++)
{
ValueTask<bool> vt = e.MoveNextAsync();
Assert.False(vt.IsCompleted);
Task producer = Task.Run(() => c.Writer.TryWrite(i));
Assert.True(await vt);
await producer;
Assert.Equal(i, e.Current);
}
}
finally
{
ValueTask vt = e.DisposeAsync();
Assert.True(vt.IsCompletedSuccessfully);
vt.GetAwaiter().GetResult();
}
}

[Theory]
[InlineData(0)]
[InlineData(1)]
[InlineData(128)]
public async Task ReadAllAsync_ProducerConsumer_ConsumesAllData(int items)
{
Channel<int> c = CreateChannel();

int producedTotal = 0, consumedTotal = 0;
await Task.WhenAll(
Task.Run(async () =>
{
for (int i = 0; i < items; i++)
{
await c.Writer.WriteAsync(i);
producedTotal += i;
}
c.Writer.Complete();
}),
Task.Run(async () =>
{
IAsyncEnumerator<int> e = c.Reader.ReadAllAsync().GetAsyncEnumerator();
try
{
while (await e.MoveNextAsync())
{
consumedTotal += e.Current;
}
}
finally
{
await e.DisposeAsync();
}
}));

Assert.Equal(producedTotal, consumedTotal);
}

[Fact]
public async Task ReadAllAsync_MultipleEnumerationsToEnd()
{
Channel<int> c = CreateChannel();

Assert.True(c.Writer.TryWrite(42));
c.Writer.Complete();

IAsyncEnumerable<int> enumerable = c.Reader.ReadAllAsync();
IAsyncEnumerator<int> e = enumerable.GetAsyncEnumerator();

Assert.True(await e.MoveNextAsync());
Assert.Equal(42, e.Current);

Assert.False(await e.MoveNextAsync());
Assert.False(await e.MoveNextAsync());

await e.DisposeAsync();

e = enumerable.GetAsyncEnumerator();
Assert.Same(enumerable, e);

Assert.False(await e.MoveNextAsync());
Assert.False(await e.MoveNextAsync());
}

[Theory]
[InlineData(false, false)]
[InlineData(false, true)]
[InlineData(true, false)]
[InlineData(true, true)]
public void ReadAllAsync_MultipleSingleElementEnumerations_AllItemsEnumerated(bool sameEnumerable, bool dispose)
{
Channel<int> c = CreateChannel();
IAsyncEnumerable<int> enumerable = c.Reader.ReadAllAsync();

for (int i = 0; i < 10; i++)
{
Assert.True(c.Writer.TryWrite(i));
IAsyncEnumerator<int> e = (sameEnumerable ? enumerable : c.Reader.ReadAllAsync()).GetAsyncEnumerator();
ValueTask<bool> vt = e.MoveNextAsync();
Assert.True(vt.IsCompletedSuccessfully);
Assert.True(vt.Result);
Assert.Equal(i, e.Current);
if (dispose)
{
ValueTask dvt = e.DisposeAsync();
Assert.True(dvt.IsCompletedSuccessfully);
dvt.GetAwaiter().GetResult();
}
}
}

[Theory]
[InlineData(false)]
[InlineData(true)]
public async Task ReadAllAsync_DualConcurrentEnumeration_AllItemsEnumerated(bool sameEnumerable)
{
if (RequiresSingleReader)
{
return;
}

Channel<int> c = CreateChannel();

IAsyncEnumerable<int> enumerable = c.Reader.ReadAllAsync();

IAsyncEnumerator<int> e1 = enumerable.GetAsyncEnumerator();
IAsyncEnumerator<int> e2 = (sameEnumerable ? enumerable : c.Reader.ReadAllAsync()).GetAsyncEnumerator();
Assert.NotSame(e1, e2);

ValueTask<bool> vt1, vt2;
int producerTotal = 0, consumerTotal = 0;
for (int i = 0; i < 10; i++)
{
vt1 = e1.MoveNextAsync();
vt2 = e2.MoveNextAsync();

await c.Writer.WriteAsync(i);
producerTotal += i;
await c.Writer.WriteAsync(i * 2);
producerTotal += i * 2;

Assert.True(await vt1);
Assert.True(await vt2);
consumerTotal += e1.Current;
consumerTotal += e2.Current;
}

vt1 = e1.MoveNextAsync();
vt2 = e2.MoveNextAsync();
c.Writer.Complete();
Assert.False(await vt1);
Assert.False(await vt2);

Assert.Equal(producerTotal, consumerTotal);
}

[Theory]
[InlineData(false)]
[InlineData(true)]
public async Task ReadAllAsync_CanceledBeforeMoveNextAsync_Throws(bool dataAvailable)
{
Channel<int> c = CreateChannel();
if (dataAvailable)
{
Assert.True(c.Writer.TryWrite(42));
}

var cts = new CancellationTokenSource();
cts.Cancel();

IAsyncEnumerator<int> e = c.Reader.ReadAllAsync(cts.Token).GetAsyncEnumerator();
ValueTask<bool> vt = e.MoveNextAsync();
Assert.True(vt.IsCompleted);
Assert.False(vt.IsCompletedSuccessfully);
OperationCanceledException oce = await Assert.ThrowsAnyAsync<OperationCanceledException>(async () => await vt);
Assert.Equal(cts.Token, oce.CancellationToken);
}

[Fact]
public async Task ReadAllAsync_CanceledAfterMoveNextAsync_Throws()
{
Channel<int> c = CreateChannel();
var cts = new CancellationTokenSource();

IAsyncEnumerator<int> e = c.Reader.ReadAllAsync(cts.Token).GetAsyncEnumerator();
ValueTask<bool> vt = e.MoveNextAsync();
Assert.False(vt.IsCompleted);

cts.Cancel();
OperationCanceledException oce = await Assert.ThrowsAnyAsync<OperationCanceledException>(async () => await vt);
Assert.Equal(cts.Token, oce.CancellationToken);

vt = e.MoveNextAsync();
Assert.True(vt.IsCompletedSuccessfully);
Assert.False(vt.Result);
}

[Fact]
public async Task WaitToReadAsync_ConsecutiveReadsSucceed()
{
Expand Down
Loading

0 comments on commit ab2d63b

Please sign in to comment.