//
//  AsyncBufferedChannel.swift
//
//
//  Created by Thibault Wittemberg on 07/01/2022.
//

import DequeModule
import OrderedCollections

/// A channel for sending elements from one task to another.
///
/// The `AsyncBufferedChannel` class is intended to be used as a communication type between tasks,
/// particularly when one task produces values and another task consumes those values. The values are
/// buffered until a consumer takes them through iteration.
/// `finish()` induces a terminal state and no further elements can be sent.
///
/// ```swift
/// let channel = AsyncBufferedChannel<Int>()
///
/// Task {
///   for await element in channel {
///     print(element) // will print 1, 2, 3
///   }
/// }
///
/// channel.send(1)
/// channel.send(2)
/// channel.send(3)
/// channel.finish()
/// ```
public final class AsyncBufferedChannel<Element>: AsyncSequence, Sendable {
  public typealias Element = Element
  public typealias AsyncIterator = Iterator

  /// A consumer suspended in `next()`: a unique id plus the continuation to resume.
  ///
  /// Hashing and equality use `id` only, so a `placeHolder(id:)` can be used to
  /// look up and remove a registered consumer without having its continuation.
  struct Awaiting: Hashable {
    let id: Int
    let continuation: UnsafeContinuation<Element?, Never>?

    /// An `Awaiting` with no continuation, usable only as a lookup key by `id`.
    static func placeHolder(id: Int) -> Awaiting {
      Awaiting(id: id, continuation: nil)
    }

    func hash(into hasher: inout Hasher) {
      hasher.combine(self.id)
    }

    static func == (lhs: Awaiting, rhs: Awaiting) -> Bool {
      lhs.id == rhs.id
    }
  }

  /// What `send(_:)` must do once it has left the critical region.
  enum SendDecision {
    /// Hand `Element` to a single waiting consumer.
    case resume(Awaiting, Element)
    /// Resume every waiting consumer with `nil` (channel finished).
    case terminate([Awaiting])
    /// Nothing to resume: the value was buffered or dropped.
    case nothing
  }

  /// What `next()` must do once it has left the critical region.
  enum AwaitingDecision {
    /// Resume the caller immediately with this value (`nil` ends iteration).
    case resume(Element?)
    /// The caller has been registered as a waiting consumer.
    case suspend
  }

  /// An item flowing through the channel: either an element or the end marker.
  enum Value {
    case element(Element)
    case termination
  }

  /// The channel state. Invariants maintained by this class:
  /// `.queued` never holds an empty buffer and `.awaiting` never holds an empty set;
  /// both fall back to `.idle` instead.
  // NOTE(review): `@unchecked Sendable` presumably relies on every access going through
  // `ManagedCriticalState.withCriticalRegion` — confirm against ManagedCriticalState.
  enum State: @unchecked Sendable {
    case idle
    case queued(Deque<Value>)
    case awaiting(OrderedSet<Awaiting>)
    case finished

    static var initial: State {
      .idle
    }
  }

  let ids: ManagedCriticalState<Int>
  let state: ManagedCriticalState<State>

  public init() {
    self.ids = ManagedCriticalState(0)
    self.state = ManagedCriticalState(.initial)
  }

  /// Returns a new identifier for a consumer (1, 2, 3, ...).
  func generateId() -> Int {
    self.ids.withCriticalRegion { ids in
      ids += 1
      return ids
    }
  }

  /// `true` when a call to `next()` would currently return without suspending.
  ///
  /// That is the case when values are buffered, or when the channel is finished
  /// (because `next()` then returns `nil` immediately).
  var hasBufferedElements: Bool {
    self.state.withCriticalRegion { state in
      switch state {
        case .idle:
          return false
        case .queued(let values) where !values.isEmpty:
          return true
        case .awaiting, .queued:
          return false
        case .finished:
          return true
      }
    }
  }

  /// Delivers `value` to the oldest waiting consumer, or buffers it if nobody is waiting.
  ///
  /// A `.termination` resumes every waiting consumer with `nil` and moves the channel to
  /// `.finished`. Values sent after the channel is finished, or after a termination has
  /// already been buffered, are ignored.
  func send(_ value: Value) {
    let decision = self.state.withCriticalRegion { state -> SendDecision in
      switch (state, value) {
        case (.idle, .element):
          state = .queued([value])
          return .nothing
        case (.idle, .termination):
          state = .finished
          return .nothing
        case (.queued(var values), _):
          // Once a termination is buffered, anything behind it would be discarded when
          // the termination is consumed, so don't keep it (or a duplicate termination).
          if case .termination? = values.last {
            return .nothing
          }
          values.append(value)
          state = .queued(values)
          return .nothing
        case (.awaiting(var awaitings), .element(let element)):
          // Consumers are served in registration order.
          let awaiting = awaitings.removeFirst()
          if awaitings.isEmpty {
            state = .idle
          } else {
            state = .awaiting(awaitings)
          }
          return .resume(awaiting, element)
        case (.awaiting(let awaitings), .termination):
          state = .finished
          return .terminate(Array(awaitings))
        case (.finished, _):
          return .nothing
      }
    }

    // Continuations are resumed only after the critical region has been left.
    switch decision {
      case .nothing:
        break
      case .terminate(let awaitings):
        awaitings.forEach { $0.continuation?.resume(returning: nil) }
      case let .resume(awaiting, element):
        awaiting.continuation?.resume(returning: element)
    }
  }

  /// Sends `element` to a waiting consumer, or buffers it until one arrives.
  /// Has no effect once the channel has been finished.
  public func send(_ element: Element) {
    self.send(.element(element))
  }

  /// Finishes the channel. Waiting consumers receive `nil`; elements already
  /// buffered are still delivered before iteration ends.
  public func finish() {
    self.send(.termination)
  }

  /// Returns the next element, suspending until one is sent if none is buffered.
  ///
  /// - Parameter onSuspend: Called when no value was immediately available and the
  ///   caller has been registered as a waiting consumer.
  /// - Returns: The next element, or `nil` when the channel is finished or the calling
  ///   task is cancelled.
  func next(onSuspend: (() -> Void)? = nil) async -> Element? {
    let awaitingId = self.generateId()
    // Records that cancellation happened, in case the handler runs before the
    // operation has registered this consumer.
    let cancellation = ManagedCriticalState<Bool>(false)

    // Lock ordering: the `state` lock is always taken before the `cancellation` lock.
    return await withTaskCancellationHandler { [state] in
      let awaiting = state.withCriticalRegion { state -> Awaiting? in
        cancellation.withCriticalRegion { cancellation in
          cancellation = true
        }
        switch state {
          case .awaiting(var awaitings):
            let awaiting = awaitings.remove(.placeHolder(id: awaitingId))
            if awaitings.isEmpty {
              state = .idle
            } else {
              state = .awaiting(awaitings)
            }
            return awaiting
          default:
            return nil
        }
      }
      awaiting?.continuation?.resume(returning: nil)
    } operation: {
      await withUnsafeContinuation { [state] (continuation: UnsafeContinuation<Element?, Never>) in
        let decision = state.withCriticalRegion { state -> AwaitingDecision in
          let isCancelled = cancellation.withCriticalRegion { $0 }
          guard !isCancelled else { return .resume(nil) }

          switch state {
            case .idle:
              state = .awaiting([Awaiting(id: awaitingId, continuation: continuation)])
              return .suspend
            case .queued(var values):
              let value = values.popFirst()
              switch value {
                case .termination:
                  state = .finished
                  return .resume(nil)
                case .element(let element) where !values.isEmpty:
                  state = .queued(values)
                  return .resume(element)
                case .element(let element):
                  state = .idle
                  return .resume(element)
                case .none:
                  // Defensive: `.queued` should never hold an empty buffer. If it does,
                  // register this consumer rather than suspending with a continuation
                  // nobody can ever resume.
                  state = .awaiting([Awaiting(id: awaitingId, continuation: continuation)])
                  return .suspend
              }
            case .awaiting(var awaitings):
              awaitings.updateOrAppend(Awaiting(id: awaitingId, continuation: continuation))
              state = .awaiting(awaitings)
              return .suspend
            case .finished:
              return .resume(nil)
          }
        }

        switch decision {
          case .resume(let element):
            continuation.resume(returning: element)
          case .suspend:
            onSuspend?()
        }
      }
    }
  }

  public func makeAsyncIterator() -> AsyncIterator {
    Iterator(
      channel: self
    )
  }

  /// Iterates the channel's elements; iteration ends once the channel is finished
  /// and its buffer drained, or when the iterating task is cancelled.
  public struct Iterator: AsyncIteratorProtocol, Sendable {
    let channel: AsyncBufferedChannel

    /// Forwards to the channel's `hasBufferedElements`.
    var hasBufferedElements: Bool {
      self.channel.hasBufferedElements
    }

    public func next() async -> Element? {
      await self.channel.next()
    }
  }
}