diff --git a/Asynchrone/Source/Sequence/SharedAsyncSequence.swift b/Asynchrone/Source/Sequence/SharedAsyncSequence.swift index e73ab1c..a12f85a 100644 --- a/Asynchrone/Source/Sequence/SharedAsyncSequence.swift +++ b/Asynchrone/Source/Sequence/SharedAsyncSequence.swift @@ -32,124 +32,97 @@ import Foundation /// let values = try await self.stream.collect() /// // ... /// ``` -public struct SharedAsyncSequence: AsyncSequence { - public typealias AsyncIterator = AsyncThrowingStream.Iterator +public struct SharedAsyncSequence: AsyncSequence { + public typealias AsyncIterator = AsyncThrowingStream.Iterator /// The kind of elements streamed. - public typealias Element = T.Element + public typealias Element = Base.Element - // MARK: SharedAsyncSequence (Private Properties) - - private let inner: Inner + // Private + private let manager: SubSequenceManager // MARK: SharedAsyncSequence (Public Properties) /// Creates a shareable async sequence that can be used across multiple tasks. /// - Parameters: /// - base: The async sequence in which this sequence receives it's elements. - public init(_ base: T) { - self.inner = Inner(base) + public init(_ base: Base) { + self.manager = SubSequenceManager(base) } // MARK: AsyncSequence /// Creates an async iterator that emits elements of this async sequence. /// - Returns: An instance that conforms to `AsyncIteratorProtocol`. - public func makeAsyncIterator() -> AsyncThrowingStream.Iterator { - inner.makeAsyncIterator() + public func makeAsyncIterator() -> AsyncThrowingStream.Iterator { + self.manager.makeAsyncIterator() } } -// MARK: - SharedAsyncSequence > Inner - -extension SharedAsyncSequence { - - fileprivate final class Inner { - - fileprivate typealias Element = T.Element +// MARK: Sub sequence manager - // MARK: Inner (Private Properties) +fileprivate actor SubSequenceManager{ + + fileprivate typealias Element = Base.Element - private var base: T - - private let lock = NSLock() - private var streams: [AsyncThrowingStream] = [] - private var continuations: [AsyncThrowingStream.Continuation] = [] - private var subscriptionTask: Task? + // Private + private var base: Base + private var sequences: [ThrowingPassthroughAsyncSequence] = [] + private var subscriptionTask: Task? - // MARK: Initialization + // MARK: Initialization - fileprivate init(_ base: T) { - self.base = base - } - - deinit { - subscriptionTask?.cancel() + fileprivate init(_ base: Base) { + self.base = base + } + + deinit { + self.subscriptionTask?.cancel() + } + + // MARK: API + + /// Creates an new stream and returns its async iterator that emits elements of base async sequence. + /// - Returns: An instance that conforms to `AsyncIteratorProtocol`. + nonisolated fileprivate func makeAsyncIterator() -> ThrowingPassthroughAsyncSequence.AsyncIterator { + let sequence = ThrowingPassthroughAsyncSequence() + Task { [sequence] in + await self.add(sequence: sequence) } - // MARK: API - - /// Creates an new stream and returns its async iterator that emits elements of base async sequence. - /// - Returns: An instance that conforms to `AsyncIteratorProtocol`. - fileprivate func makeAsyncIterator() -> AsyncThrowingStream.Iterator { - var streamContinuation: AsyncThrowingStream.Continuation! - let stream = AsyncThrowingStream { (continuation: AsyncThrowingStream.Continuation) in - streamContinuation = continuation - } - - add(stream: stream, continuation: streamContinuation) - - return stream.makeAsyncIterator() - } - - // MARK: Inner (Private Methods) - - private func add( - stream: AsyncThrowingStream, - continuation: AsyncThrowingStream.Continuation - ) { - modify { - streams.append(stream) - continuations.append(continuation) - subscribeToBaseStreamIfNeeded() - } - } + return sequence.makeAsyncIterator() + } - private func modify(_ block: () -> Void) { - lock.lock() - block() - lock.unlock() - } + // MARK: Sequence management - private func subscribeToBaseStreamIfNeeded() { - guard subscriptionTask == nil else { return } + private func add(sequence: ThrowingPassthroughAsyncSequence) { + self.sequences.append(sequence) + self.subscribeToBaseSequenceIfNeeded() + } + + private func subscribeToBaseSequenceIfNeeded() { + guard self.subscriptionTask == nil else { return } - subscriptionTask = Task { [weak self, base] in - guard let self = self else { return } + self.subscriptionTask = Task { [weak self, base] in + guard let self = self else { return } - guard !Task.isCancelled else { - self.modify { - self.continuations.forEach { $0.finish(throwing: CancellationError()) } - } - return + guard !Task.isCancelled else { + await self.sequences.forEach { + $0.finish(throwing: CancellationError()) } + return + } - do { - for try await value in base { - self.modify { - self.continuations.forEach { $0.yield(value) } - } - } - self.modify { - self.continuations.forEach { $0.finish(throwing: nil) } - } - } catch { - self.modify { - self.continuations.forEach { $0.finish(throwing: error) } - } + do { + for try await value in base { + await self.sequences.forEach { $0.yield(value) } } + + await self.sequences.forEach { $0.finish() } + } catch { + await self.sequences.forEach { $0.finish(throwing: error) } } } } diff --git a/AsynchroneTests/SharedAsyncSequenceTests.swift b/AsynchroneTests/SharedAsyncSequenceTests.swift index bf298a8..e35dd3f 100644 --- a/AsynchroneTests/SharedAsyncSequenceTests.swift +++ b/AsynchroneTests/SharedAsyncSequenceTests.swift @@ -43,7 +43,7 @@ final class SharedAsyncSequenceTests: XCTestCase { XCTAssertEqual(values[2], "abc") XCTAssertEqual(values[3], "abcd") } - + let values = try await self.stream.collect() XCTAssertEqual(values.count, 4) XCTAssertEqual(values[0], "a")