diff --git a/Amplify/Core/Support/JSONValue+Subscript.swift b/Amplify/Core/Support/JSONValue+Subscript.swift index 0b331106a8..65c72d570b 100644 --- a/Amplify/Core/Support/JSONValue+Subscript.swift +++ b/Amplify/Core/Support/JSONValue+Subscript.swift @@ -26,4 +26,8 @@ public extension JSONValue { return nil } } + + subscript(dynamicMember member: String) -> JSONValue? { + self[member] + } } diff --git a/Amplify/Core/Support/JSONValue.swift b/Amplify/Core/Support/JSONValue.swift index 4ece18e76e..afb1a243ad 100644 --- a/Amplify/Core/Support/JSONValue.swift +++ b/Amplify/Core/Support/JSONValue.swift @@ -8,6 +8,7 @@ import Foundation /// A utility type that allows us to represent an arbitrary JSON structure +@dynamicMemberLookup public enum JSONValue { case array([JSONValue]) case boolean(Bool) @@ -105,3 +106,62 @@ extension JSONValue: ExpressibleByStringLiteral { self = .string(value) } } + +extension JSONValue { + + public var asObject: [String: JSONValue]? { + if case .object(let object) = self { + return object + } + + return nil + } + + public var asArray: [JSONValue]? { + if case .array(let array) = self { + return array + } + + return nil + } + + public var stringValue: String? { + if case .string(let string) = self { + return string + } + + return nil + } + + public var intValue: Int? { + if case .number(let double) = self, + double < Double(Int.max) && double >= Double(Int.min) { + return Int(double) + } + return nil + } + + public var doubleValue: Double? { + if case .number(let double) = self { + return double + } + + return nil + } + + public var booleanValue: Bool? { + if case .boolean(let bool) = self { + return bool + } + + return nil + } + + public var isNull: Bool { + if case .null = self { + return true + } + + return false + } +} diff --git a/AmplifyPlugins/API/Sources/AWSAPIPlugin/APIError+Unauthorized.swift b/AmplifyPlugins/API/Sources/AWSAPIPlugin/APIError+Unauthorized.swift index a9f7c00bd0..7c8db5ff4f 100644 --- a/AmplifyPlugins/API/Sources/AWSAPIPlugin/APIError+Unauthorized.swift +++ b/AmplifyPlugins/API/Sources/AWSAPIPlugin/APIError+Unauthorized.swift @@ -6,7 +6,6 @@ // import Amplify -import AppSyncRealTimeClient extension APIError { diff --git a/AmplifyPlugins/API/Sources/AWSAPIPlugin/AWSAPIPlugin+Configure.swift b/AmplifyPlugins/API/Sources/AWSAPIPlugin/AWSAPIPlugin+Configure.swift index d39a0b3375..8ea423d995 100644 --- a/AmplifyPlugins/API/Sources/AWSAPIPlugin/AWSAPIPlugin+Configure.swift +++ b/AmplifyPlugins/API/Sources/AWSAPIPlugin/AWSAPIPlugin+Configure.swift @@ -7,7 +7,6 @@ import Amplify import AWSPluginsCore -import AppSyncRealTimeClient import AwsCommonRuntimeKit public extension AWSAPIPlugin { @@ -53,14 +52,14 @@ extension AWSAPIPlugin { struct ConfigurationDependencies { let authService: AWSAuthServiceBehavior let pluginConfig: AWSAPICategoryPluginConfiguration - let subscriptionConnectionFactory: SubscriptionConnectionFactory + let appSyncRealTimeClientFactory: AppSyncRealTimeClientFactoryProtocol let logLevel: Amplify.LogLevel init( configurationValues: JSONValue, apiAuthProviderFactory: APIAuthProviderFactory, authService: AWSAuthServiceBehavior? = nil, - subscriptionConnectionFactory: SubscriptionConnectionFactory? = nil, + appSyncRealTimeClientFactory: AppSyncRealTimeClientFactoryProtocol? = nil, logLevel: Amplify.LogLevel? = nil ) throws { let authService = authService @@ -72,15 +71,13 @@ extension AWSAPIPlugin { authService: authService ) - let subscriptionConnectionFactory = subscriptionConnectionFactory - ?? AWSSubscriptionConnectionFactory() - let logLevel = logLevel ?? Amplify.Logging.logLevel self.init( pluginConfig: pluginConfig, authService: authService, - subscriptionConnectionFactory: subscriptionConnectionFactory, + appSyncRealTimeClientFactory: appSyncRealTimeClientFactory + ?? AppSyncRealTimeClientFactory(), logLevel: logLevel ) } @@ -88,12 +85,12 @@ extension AWSAPIPlugin { init( pluginConfig: AWSAPICategoryPluginConfiguration, authService: AWSAuthServiceBehavior, - subscriptionConnectionFactory: SubscriptionConnectionFactory, + appSyncRealTimeClientFactory: AppSyncRealTimeClientFactoryProtocol, logLevel: Amplify.LogLevel ) { self.pluginConfig = pluginConfig self.authService = authService - self.subscriptionConnectionFactory = subscriptionConnectionFactory + self.appSyncRealTimeClientFactory = appSyncRealTimeClientFactory self.logLevel = logLevel } @@ -108,8 +105,6 @@ extension AWSAPIPlugin { func configure(using dependencies: ConfigurationDependencies) { authService = dependencies.authService pluginConfig = dependencies.pluginConfig - subscriptionConnectionFactory = dependencies.subscriptionConnectionFactory - AppSyncRealTimeClient.logLevel = AppSyncRealTimeClient.LogLevel( - rawValue: dependencies.logLevel.rawValue) ?? .error + appSyncRealTimeClientFactory = dependencies.appSyncRealTimeClientFactory } } diff --git a/AmplifyPlugins/API/Sources/AWSAPIPlugin/AWSAPIPlugin+GraphQLBehavior.swift b/AmplifyPlugins/API/Sources/AWSAPIPlugin/AWSAPIPlugin+GraphQLBehavior.swift index 6948f9195d..ae2d999245 100644 --- a/AmplifyPlugins/API/Sources/AWSAPIPlugin/AWSAPIPlugin+GraphQLBehavior.swift +++ b/AmplifyPlugins/API/Sources/AWSAPIPlugin/AWSAPIPlugin+GraphQLBehavior.swift @@ -61,7 +61,7 @@ public extension AWSAPIPlugin { let operation = AWSGraphQLSubscriptionOperation( request: request.toOperationRequest(operationType: .subscription), pluginConfig: pluginConfig, - subscriptionConnectionFactory: subscriptionConnectionFactory, + appSyncRealTimeClientFactory: appSyncRealTimeClientFactory, authService: authService, apiAuthProviderFactory: authProviderFactory, inProcessListener: valueListener, @@ -74,7 +74,7 @@ public extension AWSAPIPlugin { let request = request.toOperationRequest(operationType: .subscription) let runner = AWSGraphQLSubscriptionTaskRunner(request: request, pluginConfig: pluginConfig, - subscriptionConnectionFactory: subscriptionConnectionFactory, + appSyncClientFactory: appSyncRealTimeClientFactory, authService: authService, apiAuthProviderFactory: authProviderFactory) return runner.sequence diff --git a/AmplifyPlugins/API/Sources/AWSAPIPlugin/AWSAPIPlugin+Log.swift b/AmplifyPlugins/API/Sources/AWSAPIPlugin/AWSAPIPlugin+Log.swift index 1165e272ce..f17a43a25e 100644 --- a/AmplifyPlugins/API/Sources/AWSAPIPlugin/AWSAPIPlugin+Log.swift +++ b/AmplifyPlugins/API/Sources/AWSAPIPlugin/AWSAPIPlugin+Log.swift @@ -6,7 +6,6 @@ // import Amplify -import AppSyncRealTimeClient extension AWSAPIPlugin { var log: Logger { diff --git a/AmplifyPlugins/API/Sources/AWSAPIPlugin/AWSAPIPlugin+Resettable.swift b/AmplifyPlugins/API/Sources/AWSAPIPlugin/AWSAPIPlugin+Resettable.swift index f2644ba97f..b40fb6f4a3 100644 --- a/AmplifyPlugins/API/Sources/AWSAPIPlugin/AWSAPIPlugin+Resettable.swift +++ b/AmplifyPlugins/API/Sources/AWSAPIPlugin/AWSAPIPlugin+Resettable.swift @@ -11,6 +11,11 @@ import Foundation extension AWSAPIPlugin: Resettable { public func reset() async { + if let resettableAppSyncRealClientFactory = appSyncRealTimeClientFactory as? Resettable { + await resettableAppSyncRealClientFactory.reset() + } + appSyncRealTimeClientFactory = nil + mapper.reset() await session.cancelAndReset() @@ -24,8 +29,6 @@ extension AWSAPIPlugin: Resettable { reachabilityMapLock.execute { reachabilityMap.removeAll() } - - subscriptionConnectionFactory = nil } } diff --git a/AmplifyPlugins/API/Sources/AWSAPIPlugin/AWSAPIPlugin.swift b/AmplifyPlugins/API/Sources/AWSAPIPlugin/AWSAPIPlugin.swift index 5fe085b42e..ce124f1f54 100644 --- a/AmplifyPlugins/API/Sources/AWSAPIPlugin/AWSAPIPlugin.swift +++ b/AmplifyPlugins/API/Sources/AWSAPIPlugin/AWSAPIPlugin.swift @@ -36,7 +36,7 @@ final public class AWSAPIPlugin: NSObject, APICategoryPlugin, APICategoryGraphQL /// Creating and retrieving connections for subscriptions. This will be instantiated during the configuration phase, /// and is clearable by `reset()`. This is implicitly unwrapped to be destroyed when resetting. - var subscriptionConnectionFactory: SubscriptionConnectionFactory! + var appSyncRealTimeClientFactory: AppSyncRealTimeClientFactoryProtocol! var authProviderFactory: APIAuthProviderFactory diff --git a/AmplifyPlugins/API/Sources/AWSAPIPlugin/AppSyncRealTimeClient/AppSyncRealTimeClient+HandleRequest.swift b/AmplifyPlugins/API/Sources/AWSAPIPlugin/AppSyncRealTimeClient/AppSyncRealTimeClient+HandleRequest.swift new file mode 100644 index 0000000000..d3eee7a753 --- /dev/null +++ b/AmplifyPlugins/API/Sources/AWSAPIPlugin/AppSyncRealTimeClient/AppSyncRealTimeClient+HandleRequest.swift @@ -0,0 +1,111 @@ +// +// Copyright Amazon.com Inc. or its affiliates. +// All Rights Reserved. +// +// SPDX-License-Identifier: Apache-2.0 +// + + +import Foundation +import Combine +import Amplify + +extension AppSyncRealTimeClient { + /** + Submit an AppSync request to real-time server. + - Returns: + Void indicates request is finished successfully + - Throws: + Error is throwed when request is failed + */ + func sendRequest( + _ request: AppSyncRealTimeRequest, + timeout: TimeInterval = 5 + ) async throws { + var responseSubscriptions = Set() + try await withCheckedThrowingContinuation { [weak self] (continuation: CheckedContinuation) in + guard let self else { + Self.log.debug("[AppSyncRealTimeClient] client has already been disposed") + continuation.resume(returning: ()) + return + } + + // listen to response + self.subject + .setFailureType(to: AppSyncRealTimeRequest.Error.self) + .flatMap { Self.filterResponse(request: request, response: $0) } + .timeout(.seconds(timeout), scheduler: DispatchQueue.global(qos: .userInitiated), customError: { .timeout }) + .first() + .sink(receiveCompletion: { completion in + switch completion { + case .finished: + continuation.resume(returning: ()) + case .failure(let error): + continuation.resume(throwing: error) + } + }, receiveValue: { _ in }) + .store(in: &responseSubscriptions) + + // sending request; error is discarded and will be classified as timeout + Task { + do { + let decoratedRequest = await self.requestInterceptor.interceptRequest( + event: request, + url: self.endpoint + ) + let requestJSON = String(data: try Self.jsonEncoder.encode(decoratedRequest), encoding: .utf8) + + try await self.webSocketClient.write(message: requestJSON!) + } catch { + Self.log.debug("[AppSyncRealTimeClient]Failed to send AppSync request \(request), error: \(error)") + } + } + } + } + + private static func filterResponse( + request: AppSyncRealTimeRequest, + response: AppSyncRealTimeResponse + ) -> AnyPublisher { + let justTheResponse = Just(response) + .setFailureType(to: AppSyncRealTimeRequest.Error.self) + .eraseToAnyPublisher() + + switch (request, response.type) { + case (.connectionInit, .connectionAck): + return justTheResponse + + case (.start(let startRequest), .startAck) where startRequest.id == response.id: + return justTheResponse + + case (.stop(let id), .stopAck) where id == response.id: + return justTheResponse + + case (_, .error) + where request.id != nil + && request.id == response.id + && response.payload?.errors != nil: + let errorsJson: JSONValue = (response.payload?.errors)! + let errors = errorsJson.asArray ?? [errorsJson] + let reqeustErrors = errors.compactMap(AppSyncRealTimeRequest.parseResponseError(error:)) + if reqeustErrors.isEmpty { + return Empty( + outputType: AppSyncRealTimeResponse.self, + failureType: AppSyncRealTimeRequest.Error.self + ).eraseToAnyPublisher() + } else { + return Fail( + outputType: AppSyncRealTimeResponse.self, + failure: reqeustErrors.first! + ).eraseToAnyPublisher() + } + + default: + return Empty( + outputType: AppSyncRealTimeResponse.self, + failureType: AppSyncRealTimeRequest.Error.self + ).eraseToAnyPublisher() + + } + } +} diff --git a/AmplifyPlugins/API/Sources/AWSAPIPlugin/AppSyncRealTimeClient/AppSyncRealTimeClient.swift b/AmplifyPlugins/API/Sources/AWSAPIPlugin/AppSyncRealTimeClient/AppSyncRealTimeClient.swift new file mode 100644 index 0000000000..c8bf7efcab --- /dev/null +++ b/AmplifyPlugins/API/Sources/AWSAPIPlugin/AppSyncRealTimeClient/AppSyncRealTimeClient.swift @@ -0,0 +1,464 @@ +// +// Copyright Amazon.com Inc. or its affiliates. +// All Rights Reserved. +// +// SPDX-License-Identifier: Apache-2.0 +// + + +import Foundation +import Amplify +import Combine +@_spi(WebSocket) import AWSPluginsCore + +/** + The AppSyncRealTimeClient conforms to the AppSync real-time WebSocket protocol. + ref: https://docs.aws.amazon.com/appsync/latest/devguide/real-time-websocket-client.html + */ +actor AppSyncRealTimeClient: AppSyncRealTimeClientProtocol { + + static let jsonEncoder = JSONEncoder() + static let jsonDecoder = JSONDecoder() + + enum State { + case none + case connecting + case connected + case connectionDropped + case disconnecting + case disconnected + } + + /// Internal state for tracking AppSync connection + private let state = CurrentValueSubject(.none) + /// Subscriptions created using this client + private var subscriptions = [String: AppSyncRealTimeSubscription]() + /// heart beat stream to keep connection alive + private let heartBeats = PassthroughSubject() + /// Cancellables bind to instance life cycle + private var cancellables = Set() + /// Cancellables bind to connection life cycle + private var cancellablesBindToConnection = Set() + + /// AppSync RealTime server endpoint + internal let endpoint: URL + /// Interceptor for decorating AppSyncRealTimeRequest + internal let requestInterceptor: AppSyncRequestInterceptor + + /// WebSocketClient offering connections at the WebSocket protocol level + internal var webSocketClient: AppSyncWebSocketClientProtocol + /// Writable data stream convert WebSocketEvent to AppSyncRealTimeResponse + internal let subject = PassthroughSubject() + + var isConnected: Bool { + self.state.value == .connected + } + + /** + Creates a new AppSyncRealTimeClient with endpoint, requestInterceptor and webSocketClient. + - Parameters: + - endpoint: AppSync real-time server endpoint + - requestInterceptor: Interceptor for decocating AppSyncRealTimeRequest + - webSocketClient: WebSocketClient for reading/writing to connection + */ + init( + endpoint: URL, + requestInterceptor: AppSyncRequestInterceptor, + webSocketClient: AppSyncWebSocketClientProtocol + ) { + self.endpoint = endpoint + self.requestInterceptor = requestInterceptor + + self.webSocketClient = webSocketClient + + Task { await self.subscribeToWebSocketEvent() } + } + + deinit { + log.debug("Deinit AppSyncRealTimeClient") + subject.send(completion: .finished) + cancellables = Set() + cancellablesBindToConnection = Set() + } + + /** + Connecting to remote AppSync real-time server. + */ + func connect() async throws { + switch self.state.value { + case .connecting, .connected: + log.debug("[AppSyncRealTimeClient] client is already connecting or connected") + return + case .disconnecting: + try await waitForState(.disconnected) + case .connectionDropped, .disconnected, .none: + break + } + + guard self.state.value != .connecting else { + log.debug("[AppSyncRealTimeClient] actor reentry, state has been changed to connecting") + return + } + + self.state.send(.connecting) + log.debug("[AppSyncRealTimeClient] client start connecting") + + try await RetryWithJitter.execute { [weak self] in + guard let self else { return } + await self.webSocketClient.connect( + autoConnectOnNetworkStatusChange: true, + autoRetryOnConnectionFailure: true + ) + try await self.sendRequest(.connectionInit) + } + } + + /** + Disconnect only when there are no subscriptions exist. + */ + func disconnectWhenIdel() async { + if self.subscriptions.isEmpty { + log.debug("[AppSyncRealTimeClient] no subscription exist, client is trying to disconnect") + await disconnect() + } else { + log.debug("[AppSyncRealTimeClient] client only try to disconnect when no subscriptions exist") + } + } + + /** + Disconnect from AppSync real-time server. + */ + func disconnect() async { + guard self.state.value != .disconnecting else { + log.debug("[AppSyncRealTimeClient] client already disconnecting") + return + } + + defer { self.state.send(.disconnected) } + + log.debug("[AppSyncRealTimeClient] client start disconnecting") + self.state.send(.disconnecting) + self.cancellablesBindToConnection = Set() + await self.webSocketClient.disconnect() + log.debug("[AppSyncRealTimeClient] client is disconnected") + } + + /** + Subscribing to a query with unique identifier. + - Parameters: + - id: unique identifier + - query: GraphQL query for subscription + + - Returns: + A never fail data stream for AppSyncSubscriptionEvent. + */ + func subscribe(id: String, query: String) async throws -> AnyPublisher { + log.debug("[AppSyncRealTimeClient] Received subscription request id: \(id), query: \(query)") + let subscription = AppSyncRealTimeSubscription(id: id, query: query, appSyncRealTimeClient: self) + subscriptions[id] = subscription + + + // Placing the actual subscription work in a deferred task and + // promptly returning the filtered publisher for downstream consumption of all error messages. + defer { + Task { [weak self] in + guard let self = self else { return } + if !(await self.isConnected) { + try await connect() + try await waitForState(.connected) + } + await self.bindCancellableToConnection(try await self.startSubscription(id)) + }.toAnyCancellable.store(in: &cancellablesBindToConnection) + } + + return filterAppSyncSubscriptionEvent(with: id) + .merge(with: (await subscription.publisher).toAppSyncSubscriptionEventStream()) + .eraseToAnyPublisher() + } + + private func waitForState(_ targetState: State) async throws { + var cancellables = Set() + + try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) -> Void in + state.filter { $0 == targetState } + .setFailureType(to: AppSyncRealTimeRequest.Error.self) + .timeout(.seconds(10), scheduler: DispatchQueue.global()) + .first() + .sink { completion in + switch completion { + case .finished: + continuation.resume(returning: ()) + case .failure(let error): + continuation.resume(throwing: error) + } + } receiveValue: { _ in } + .store(in: &cancellables) + } + } + + /** + Unsubscribe a subscription with unique identifier. + - Parameters: + - id: unique identifier of the subscription. + */ + func unsubscribe(id: String) async throws { + defer { + log.debug("[AppSyncRealTimeClient] deleted subscription with id: \(id)") + subscriptions.removeValue(forKey: id) + } + + guard let subscription = subscriptions[id] else { + log.debug("[AppSyncRealTimeClient] start subscription failed, could not found subscription with id \(id) ") + return + } + log.debug("[AppSyncRealTimeClient] unsubscribing: \(id)") + try await subscription.unsubscribe() + } + + private func startSubscription(_ id: String) async throws -> AnyCancellable { + guard let subscription = subscriptions[id] else { + log.debug("[AppSyncRealTimeClient] start subscription failed, could not found subscription with id \(id) ") + throw APIError.unknown("Could not find a subscription with id \(id)", "", nil) + } + + try await subscription.subscribe() + + return AnyCancellable { + Task { + try await subscription.unsubscribe() + } + } + + } + + private func subscribeToWebSocketEvent() async { + await self.webSocketClient.publisher.sink { [weak self] _ in + self?.log.debug("[AppSyncRealTimeClient] WebSocketClient terminated") + } receiveValue: { webSocketEvent in + Task { [weak self] in + await self?.onWebSocketEvent(webSocketEvent) + }.toAnyCancellable.store(in: &self.cancellables) + } + .store(in: &cancellables) + } + + private func resumeExistingSubscriptions() { + log.debug("[AppSyncRealTimeClient] Resuming existing subscriptions") + for (id, _) in self.subscriptions { + Task { + do { + try await self.startSubscription(id).store(in: &cancellablesBindToConnection) + } catch { + log.debug("[AppSyncRealTimeClient] Failed to resume existing subscription with id: (\(id))") + } + } + } + + } + + nonisolated private func writeAppSyncEvent(_ event: AppSyncRealTimeRequest) async throws { + guard await self.webSocketClient.isConnected else { + log.debug("[AppSyncRealTimeClient] Attempting to write to a webSocket haven't been connected.") + return + } + + let interceptedEvent = await self.requestInterceptor.interceptRequest(event: event, url: self.endpoint) + let eventString = try String(data: Self.jsonEncoder.encode(interceptedEvent), encoding: .utf8)! + log.debug("[AppSyncRealTimeClient] Writing AppSyncEvent \(eventString)") + try await webSocketClient.write(message: eventString) + } + + /** + Filter response to downstream by id. + - Parameters: + - id: subscription identifier + - Returns: + - AppSyncSubscriptionEvent data stream related to subscription + - important: connection errors will also be passed to downstreams + */ + private func filterAppSyncSubscriptionEvent( + with id: String + ) -> AnyPublisher { + subject.filter { $0.id == id || $0.type == .connectionError } + .map { response -> AppSyncSubscriptionEvent? in + switch response.type { + case .connectionError, .error: + return .error(Self.decodeAppSyncRealTimeResponseError(response.payload)) + case .data: + return response.payload.map { .data($0) } + default: + return nil + } + } + .compactMap { $0 } + .eraseToAnyPublisher() + } + + private static func decodeAppSyncRealTimeResponseError(_ data: JSONValue?) -> [Error] { + let knownAppSyncRealTimeRequestErorrs = + Self.decodeAppSyncRealTimeRequestError(data) + .filter { !$0.isUnknown } + if knownAppSyncRealTimeRequestErorrs.isEmpty { + let graphQLErrors = Self.decodeGraphQLErrors(data) + return graphQLErrors.isEmpty + ? [APIError.operationError("Failed to decode AppSync error response", "", nil)] + : graphQLErrors + } else { + return knownAppSyncRealTimeRequestErorrs + } + } + + private static func decodeGraphQLErrors(_ data: JSONValue?) -> [GraphQLError] { + do { + return try GraphQLErrorDecoder.decodeAppSyncErrors(data) + } catch { + log.debug("[AppSyncRealTimeClient] Failed to decode errors: \(error)") + return [] + } + } + + private static func decodeAppSyncRealTimeRequestError(_ data: JSONValue?) -> [AppSyncRealTimeRequest.Error] { + guard let errorsJson = data?.errors else { + log.error("[AppSyncRealTimeClient] No 'errors' field found in response json") + return [] + } + let errors = errorsJson.asArray ?? [errorsJson] + return errors.compactMap(AppSyncRealTimeRequest.parseResponseError(error:)) + } + + private func bindCancellableToConnection(_ cancellable: AnyCancellable) { + cancellable.store(in: &cancellablesBindToConnection) + } + +} + +// MARK: - On WebSocket Events +extension AppSyncRealTimeClient { + private func onWebSocketEvent(_ event: WebSocketEvent) { + log.debug("[AppSyncRealTimeClient] Received websocket event \(event)") + switch event { + case .connected: + log.debug("[AppSyncRealTimeClient] WebSocket connected") + if self.state.value == .connectionDropped { + log.debug("[AppSyncRealTimeClient] reconnecting appSyncClient after connection drop") + Task { [weak self] in + try? await self?.connect() + }.toAnyCancellable.store(in: &cancellablesBindToConnection) + } + + case let .disconnected(closeCode, reason): // + log.debug("[AppSyncRealTimeClient] WebSocket disconnected with closeCode: \(closeCode), reason: \(String(describing: reason))") + if self.state.value != .disconnecting || self.state.value != .disconnected { + self.state.send(.connectionDropped) + } + self.cancellablesBindToConnection = Set() + + case .error(let error): + // Since we've activated auto-reconnect functionality in WebSocketClient upon connection failure, + // we only record errors here for debugging purposes. + log.debug("[AppSyncRealTimeClient] WebSocket error event: \(error)") + case .string(let string): + guard let data = string.data(using: .utf8) else { + log.debug("[AppSyncRealTimeClient] Failed to decode string \(string)") + return + } + guard let response = try? Self.jsonDecoder.decode(AppSyncRealTimeResponse.self, from: data) else { + log.debug("[AppSyncRealTimeClient] Failed to decode string to AppSync event") + return + } + self.onAppSyncRealTimeResponse(response) + + case .data(let data): + guard let response = try? Self.jsonDecoder.decode(AppSyncRealTimeResponse.self, from: data) else { + log.debug("[AppSyncRealTimeClient] Failed to decode data to AppSync event") + return + } + self.onAppSyncRealTimeResponse(response) + } + } + +} + +// MARK: - On AppSyncServer Event +extension AppSyncRealTimeClient { + /// handles connection level response and passes request level response to downstream + private func onAppSyncRealTimeResponse(_ event: AppSyncRealTimeResponse) { + switch event.type { + case .connectionAck: + log.debug("[AppSyncRealTimeClient] AppSync connected: \(String(describing: event.payload))") + subject.send(event) + + self.resumeExistingSubscriptions() + self.state.send(.connected) + self.monitorHeartBeats(event.payload) + + case .keepAlive: + self.heartBeats.send(()) + + default: + log.debug("[AppSyncRealTimeClient] AppSync received response: \(event)") + subject.send(event) + } + } + + private func monitorHeartBeats(_ connectionAck: JSONValue?) { + let timeoutMs = connectionAck?.connectionTimeoutMs?.intValue ?? 0 + log.debug("[AppSyncRealTimeClient] Starting heart beat monitor with interval \(timeoutMs) ms") + heartBeats.eraseToAnyPublisher() + .debounce(for: .milliseconds(timeoutMs), scheduler: DispatchQueue.global()) + .first() + .sink(receiveValue: { + self.log.debug("[AppSyncRealTimeClient] KeepAlive timed out, disconnecting") + Task { [weak self] in + await self?.disconnect() + }.toAnyCancellable.store(in: &self.cancellables) + }) + .store(in: &cancellablesBindToConnection) + // start counting down + heartBeats.send(()) + } +} + +extension Publisher where Output == AppSyncRealTimeSubscription.State, Failure == Never { + func toAppSyncSubscriptionEventStream() -> AnyPublisher { + self.compactMap { subscriptionState -> AppSyncSubscriptionEvent? in + switch subscriptionState { + case .subscribing: return .subscribing + case .subscribed: return .subscribed + case .unsubscribed: return .unsubscribed + default: return nil + } + } + .eraseToAnyPublisher() + } +} + +extension AppSyncRealTimeClient: DefaultLogger { + static var log: Logger { + Amplify.Logging.logger(forCategory: CategoryType.api.displayName, forNamespace: String(describing: self)) + } + + nonisolated var log: Logger { Self.log } +} + +extension AppSyncRealTimeClient: Resettable { + func reset() async { + subject.send(completion: .finished) + cancellables = Set() + cancellablesBindToConnection = Set() + + if let resettableWebSocketClient = webSocketClient as? Resettable { + await resettableWebSocketClient.reset() + } + } +} + +fileprivate extension Task { + var toAnyCancellable: AnyCancellable { + AnyCancellable { + if !self.isCancelled { + self.cancel() + } + } + } +} diff --git a/AmplifyPlugins/API/Sources/AWSAPIPlugin/AppSyncRealTimeClient/AppSyncRealTimeRequest.swift b/AmplifyPlugins/API/Sources/AWSAPIPlugin/AppSyncRealTimeClient/AppSyncRealTimeRequest.swift new file mode 100644 index 0000000000..19599820b4 --- /dev/null +++ b/AmplifyPlugins/API/Sources/AWSAPIPlugin/AppSyncRealTimeClient/AppSyncRealTimeRequest.swift @@ -0,0 +1,128 @@ +// +// Copyright Amazon.com Inc. or its affiliates. +// All Rights Reserved. +// +// SPDX-License-Identifier: Apache-2.0 +// + + +import Foundation +import Combine +import Amplify + +public enum AppSyncRealTimeRequest { + case connectionInit + case start(StartRequest) + case stop(String) + + public struct StartRequest { + let id: String + let data: String + let auth: AppSyncRealTimeRequestAuth? + } + + var id: String? { + switch self { + case let .start(request): return request.id + case let .stop(id): return id + default: return nil + } + } +} + +extension AppSyncRealTimeRequest: Encodable { + enum CodingKeys: CodingKey { + case type + case payload + case id + } + + enum PayloadCodingKeys: CodingKey { + case data + case extensions + } + + enum ExtensionsCodingKeys: CodingKey { + case authorization + } + + public func encode(to encoder: Encoder) throws { + var container = encoder.container(keyedBy: CodingKeys.self) + switch self { + case .connectionInit: + try container.encode("connection_init", forKey: .type) + case .start(let startRequest): + try container.encode("start", forKey: .type) + try container.encode(startRequest.id, forKey: .id) + + let payloadEncoder = container.superEncoder(forKey: .payload) + var payloadContainer = payloadEncoder.container(keyedBy: PayloadCodingKeys.self) + try payloadContainer.encode(startRequest.data, forKey: .data) + + let extensionEncoder = payloadContainer.superEncoder(forKey: .extensions) + var extensionContainer = extensionEncoder.container(keyedBy: ExtensionsCodingKeys.self) + try extensionContainer.encodeIfPresent(startRequest.auth, forKey: .authorization) + case .stop(let id): + try container.encode("stop", forKey: .type) + try container.encode(id, forKey: .id) + } + } +} + + +extension AppSyncRealTimeRequest { + public enum Error: Swift.Error, Equatable { + case timeout + case limitExceeded + case maxSubscriptionsReached + case unauthorized + case unknown(message: String? = nil, causedBy: Swift.Error? = nil, payload: [String: Any]?) + + var isUnknown: Bool { + if case .unknown = self { + return true + } + return false + } + + public static func == (lhs: AppSyncRealTimeRequest.Error, rhs: AppSyncRealTimeRequest.Error) -> Bool { + switch (lhs, rhs) { + case (.timeout, .timeout), + (.limitExceeded, .limitExceeded), + (.maxSubscriptionsReached, .maxSubscriptionsReached), + (.unauthorized, .unauthorized): + return true + default: + return false + } + } + } + + + public static func parseResponseError( + error: JSONValue + ) -> AppSyncRealTimeRequest.Error? { + let limitExceededErrorString = "LimitExceededError" + let maxSubscriptionsReachedErrorString = "MaxSubscriptionsReachedError" + let unauthorized = "Unauthorized" + + guard let errorType = error.errorType?.stringValue else { + return nil + } + + switch errorType { + case _ where errorType.contains(limitExceededErrorString): + return .limitExceeded + case _ where errorType.contains(maxSubscriptionsReachedErrorString): + return .maxSubscriptionsReached + case _ where errorType.contains(unauthorized): + return .unauthorized + default: + return .unknown( + message: error.message?.stringValue, + causedBy: nil, + payload: error.asObject + ) + } + } +} diff --git a/AmplifyPlugins/API/Sources/AWSAPIPlugin/AppSyncRealTimeClient/AppSyncRealTimeRequestAuth.swift b/AmplifyPlugins/API/Sources/AWSAPIPlugin/AppSyncRealTimeClient/AppSyncRealTimeRequestAuth.swift new file mode 100644 index 0000000000..87e01b1842 --- /dev/null +++ b/AmplifyPlugins/API/Sources/AWSAPIPlugin/AppSyncRealTimeClient/AppSyncRealTimeRequestAuth.swift @@ -0,0 +1,127 @@ +// +// Copyright Amazon.com Inc. or its affiliates. +// All Rights Reserved. +// +// SPDX-License-Identifier: Apache-2.0 +// + + +import Foundation + +public enum AppSyncRealTimeRequestAuth { + case authToken(AuthToken) + case apiKey(ApiKey) + case iam(IAM) + + public struct AuthToken { + let host: String + let authToken: String + } + + public struct ApiKey { + let host: String + let apiKey: String + let amzDate: String + } + + public struct IAM { + let host: String + let authToken: String + let securityToken: String + let amzDate: String + } + + public struct URLQuery { + let header: AppSyncRealTimeRequestAuth + let payload: String + + init(header: AppSyncRealTimeRequestAuth, payload: String = "{}") { + self.header = header + self.payload = payload + } + + func withBaseURL(_ url: URL, encoder: JSONEncoder? = nil) -> URL { + let jsonEncoder: JSONEncoder = encoder ?? JSONEncoder() + guard let headerJsonData = try? jsonEncoder.encode(header) else { + return url + } + + guard var urlComponents = URLComponents(url: url, resolvingAgainstBaseURL: false) + else { + return url + } + + urlComponents.queryItems = [ + URLQueryItem(name: "header", value: headerJsonData.base64EncodedString()), + URLQueryItem(name: "payload", value: try? payload.base64EncodedString()) + ] + + return urlComponents.url ?? url + } + } +} + +extension AppSyncRealTimeRequestAuth: Encodable { + public func encode(to encoder: Encoder) throws { + var container = encoder.singleValueContainer() + switch self { + case .apiKey(let apiKey): + try container.encode(apiKey) + case .authToken(let cognito): + try container.encode(cognito) + case .iam(let iam): + try container.encode(iam) + } + } +} + +extension AppSyncRealTimeRequestAuth.AuthToken: Encodable { + enum CodingKeys: String, CodingKey { + case host + case authToken = "Authorization" + } + + public func encode(to encoder: Encoder) throws { + var container = encoder.container(keyedBy: CodingKeys.self) + try container.encode(host, forKey: .host) + try container.encode(authToken, forKey: .authToken) + } +} + +extension AppSyncRealTimeRequestAuth.ApiKey: Encodable { + enum CodingKeys: String, CodingKey { + case host + case apiKey = "x-api-key" + case amzDate = "x-amz-date" + } + + public func encode(to encoder: Encoder) throws { + var container = encoder.container(keyedBy: CodingKeys.self) + try container.encode(host, forKey: .host) + try container.encode(apiKey, forKey: .apiKey) + try container.encode(amzDate, forKey: .amzDate) + } +} + +extension AppSyncRealTimeRequestAuth.IAM: Encodable { + enum CodingKeys: String, CodingKey { + case host + case accept + case contentType = "content-type" + case authToken = "Authorization" + case securityToken = "X-Amz-Security-Token" + case contentEncoding = "content-encoding" + case amzDate = "x-amz-date" + } + + public func encode(to encoder: Encoder) throws { + var container = encoder.container(keyedBy: CodingKeys.self) + try container.encode(host, forKey: .host) + try container.encode("application/json, text/javascript", forKey: .accept) + try container.encode("application/json; charset=UTF-8", forKey: .contentType) + try container.encode("amz-1.0", forKey: .contentEncoding) + try container.encode(securityToken, forKey: .securityToken) + try container.encode(authToken, forKey: .authToken) + try container.encode(amzDate, forKey: .amzDate) + } +} diff --git a/AmplifyPlugins/API/Sources/AWSAPIPlugin/AppSyncRealTimeClient/AppSyncRealTimeResponse.swift b/AmplifyPlugins/API/Sources/AWSAPIPlugin/AppSyncRealTimeClient/AppSyncRealTimeResponse.swift new file mode 100644 index 0000000000..dfec371035 --- /dev/null +++ b/AmplifyPlugins/API/Sources/AWSAPIPlugin/AppSyncRealTimeClient/AppSyncRealTimeResponse.swift @@ -0,0 +1,30 @@ +// +// Copyright Amazon.com Inc. or its affiliates. +// All Rights Reserved. +// +// SPDX-License-Identifier: Apache-2.0 +// + +import Foundation +import Amplify + +public struct AppSyncRealTimeResponse { + + public let id: String? + public let payload: JSONValue? + public let type: EventType + + public enum EventType: String, Codable { + case connectionAck = "connection_ack" + case startAck = "start_ack" + case stopAck = "complete" + case data + case error + case connectionError = "connection_error" + case keepAlive = "ka" + case starting + } +} + +extension AppSyncRealTimeResponse: Decodable { +} diff --git a/AmplifyPlugins/API/Sources/AWSAPIPlugin/AppSyncRealTimeClient/AppSyncRealTimeSubscription.swift b/AmplifyPlugins/API/Sources/AWSAPIPlugin/AppSyncRealTimeClient/AppSyncRealTimeSubscription.swift new file mode 100644 index 0000000000..d7e4c6ef42 --- /dev/null +++ b/AmplifyPlugins/API/Sources/AWSAPIPlugin/AppSyncRealTimeClient/AppSyncRealTimeSubscription.swift @@ -0,0 +1,129 @@ +// +// Copyright Amazon.com Inc. or its affiliates. +// All Rights Reserved. +// +// SPDX-License-Identifier: Apache-2.0 +// + + +import Foundation +import Combine +import Amplify +@_spi(WebSocket) import AWSPluginsCore + +/** + AppSyncRealTimeSubscription reprensents one realtime subscription to AppSync realtime server. + */ +actor AppSyncRealTimeSubscription { + static let jsonEncoder = JSONEncoder() + + enum State { + case none + case subscribing + case subscribed + case unsubscribing + case unsubscribed + case failure + } + + /// internal state for tracking subscription status + private let state = CurrentValueSubject(.none) + + /// publisher for monitoring subscription status + public var publisher: AnyPublisher { + state.eraseToAnyPublisher() + } + + private weak var appSyncRealTimeClient: AppSyncRealTimeClient? + + public let id: String + public let query: String + + + init(id: String, query: String, appSyncRealTimeClient: AppSyncRealTimeClient) { + self.id = id + self.query = query + self.appSyncRealTimeClient = appSyncRealTimeClient + } + + deinit { + self.state.send(completion: .finished) + } + + func subscribe() async throws { + guard self.state.value != .subscribing else { + log.debug("[AppSyncRealTimeSubscription-\(id)] Subscription already in subscribing state") + return + } + + guard self.state.value != .subscribed else { + log.debug("[AppSyncRealTimeSubscription-\(id)] Subscription already in subscribed state") + return + } + + log.debug("[AppSyncRealTimeSubscription-\(id)] Start subscribing") + self.state.send(.subscribing) + + do { + try await RetryWithJitter.execute(shouldRetryOnError: { error in + (error as? AppSyncRealTimeRequest.Error) == .maxSubscriptionsReached + }) { [weak self] in + guard let self else { return } + try await self.appSyncRealTimeClient?.sendRequest( + .start(.init(id: self.id, data: self.query, auth: nil)) + ) + } + } catch { + log.debug("[AppSyncRealTimeSubscription-\(id)] Failed to subscribe, error: \(error)") + self.state.send(.failure) + throw error + } + + log.debug("[AppSyncRealTimeSubscription-\(id)] Subscribed") + self.state.send(.subscribed) + } + + func unsubscribe() async throws { + guard self.state.value == .subscribed else { + log.debug("[AppSyncRealTimeSubscription-\(id)] Subscription should be subscribed to be unsubscribed") + return + } + + log.debug("[AppSyncRealTimeSubscription-\(id)] Unsubscribing") + self.state.send(.unsubscribing) + + do { + let request = AppSyncRealTimeRequest.stop(id) + try await appSyncRealTimeClient?.sendRequest(request) + } catch { + log.debug("[AppSyncRealTimeSubscription-\(id)] Failed to unsubscribe, error \(error)") + self.state.send(.failure) + throw error + } + + log.debug("[AppSyncRealTimeSubscription-\(id)] Unsubscribed") + self.state.send(.unsubscribed) + } + + private static func sendAppSyncRealTimeRequest( + _ request: AppSyncRealTimeRequest, + with webSocketClient: AppSyncWebSocketClientProtocol + ) async throws { + guard let requestJson = try String( + data: Self.jsonEncoder.encode(request), + encoding: .utf8 + ) else { + return + } + + try await webSocketClient.write(message: requestJson) + } +} + +extension AppSyncRealTimeSubscription: DefaultLogger { + static var log: Logger { + Amplify.Logging.logger(forCategory: CategoryType.api.displayName, forNamespace: String(describing: self)) + } + + nonisolated var log: Logger { Self.log } +} diff --git a/AmplifyPlugins/API/Sources/AWSAPIPlugin/AppSyncRealTimeClient/AppSyncRequestInterceptor.swift b/AmplifyPlugins/API/Sources/AWSAPIPlugin/AppSyncRealTimeClient/AppSyncRequestInterceptor.swift new file mode 100644 index 0000000000..92414ea28c --- /dev/null +++ b/AmplifyPlugins/API/Sources/AWSAPIPlugin/AppSyncRealTimeClient/AppSyncRequestInterceptor.swift @@ -0,0 +1,13 @@ +// +// Copyright Amazon.com Inc. or its affiliates. +// All Rights Reserved. +// +// SPDX-License-Identifier: Apache-2.0 +// + + +import Foundation + +protocol AppSyncRequestInterceptor { + func interceptRequest(event: AppSyncRealTimeRequest, url: URL) async -> AppSyncRealTimeRequest +} diff --git a/AmplifyPlugins/API/Sources/AWSAPIPlugin/AppSyncRealTimeClient/AppSyncSubscriptionEvent.swift b/AmplifyPlugins/API/Sources/AWSAPIPlugin/AppSyncRealTimeClient/AppSyncSubscriptionEvent.swift new file mode 100644 index 0000000000..ec86c53e6a --- /dev/null +++ b/AmplifyPlugins/API/Sources/AWSAPIPlugin/AppSyncRealTimeClient/AppSyncSubscriptionEvent.swift @@ -0,0 +1,18 @@ +// +// Copyright Amazon.com Inc. or its affiliates. +// All Rights Reserved. +// +// SPDX-License-Identifier: Apache-2.0 +// + + +import Foundation +import Amplify + +public enum AppSyncSubscriptionEvent { + case subscribing + case subscribed + case data(JSONValue) + case unsubscribed + case error([Error]) +} diff --git a/AmplifyPlugins/API/Sources/AWSAPIPlugin/AppSyncRealTimeClient/AppSyncWebSocketClientProtocol.swift b/AmplifyPlugins/API/Sources/AWSAPIPlugin/AppSyncRealTimeClient/AppSyncWebSocketClientProtocol.swift new file mode 100644 index 0000000000..d7d9cadc29 --- /dev/null +++ b/AmplifyPlugins/API/Sources/AWSAPIPlugin/AppSyncRealTimeClient/AppSyncWebSocketClientProtocol.swift @@ -0,0 +1,28 @@ +// +// Copyright Amazon.com Inc. or its affiliates. +// All Rights Reserved. +// +// SPDX-License-Identifier: Apache-2.0 +// + + +import Foundation +import Combine +@_spi(WebSocket) import AWSPluginsCore + +protocol AppSyncWebSocketClientProtocol: AnyObject { + var isConnected: Bool { get async } + var publisher: AnyPublisher { get async } + + func connect( + autoConnectOnNetworkStatusChange: Bool, + autoRetryOnConnectionFailure: Bool + ) async + + func disconnect() async + + func write(message: String) async throws +} + +extension WebSocketClient: AppSyncWebSocketClientProtocol { } + diff --git a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Interceptor/SubscriptionInterceptor/APIKeyAuthInterceptor.swift b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Interceptor/SubscriptionInterceptor/APIKeyAuthInterceptor.swift new file mode 100644 index 0000000000..f52ded490e --- /dev/null +++ b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Interceptor/SubscriptionInterceptor/APIKeyAuthInterceptor.swift @@ -0,0 +1,59 @@ +// +// Copyright Amazon.com Inc. or its affiliates. +// All Rights Reserved. +// +// SPDX-License-Identifier: Apache-2.0 +// + + +import Foundation +import Amplify +@_spi(WebSocket) import AWSPluginsCore + +class APIKeyAuthInterceptor { + private let apiKey: String + private let getAuthHeader = authHeaderBuilder() + + init(apiKey: String) { + self.apiKey = apiKey + } + +} + +extension APIKeyAuthInterceptor: WebSocketInterceptor { + func interceptConnection(url: URL) async -> URL { + let authHeader = getAuthHeader(apiKey, AppSyncRealTimeClientFactory.appSyncApiEndpoint(url).host!) + return AppSyncRealTimeRequestAuth.URLQuery( + header: .apiKey(authHeader) + ).withBaseURL(url) + } +} + +extension APIKeyAuthInterceptor: AppSyncRequestInterceptor { + func interceptRequest(event: AppSyncRealTimeRequest, url: URL) async -> AppSyncRealTimeRequest { + let host = AppSyncRealTimeClientFactory.appSyncApiEndpoint(url).host! + guard case .start(let request) = event else { + return event + } + return .start(.init( + id: request.id, + data: request.data, + auth: .apiKey(getAuthHeader(apiKey, host)) + )) + } +} + +fileprivate func authHeaderBuilder() -> (String, String) -> AppSyncRealTimeRequestAuth.ApiKey { + let formatter = DateFormatter() + formatter.timeZone = TimeZone(secondsFromGMT: 0) + formatter.locale = Locale(identifier: "en_US_POSIX") + formatter.dateFormat = "yyyyMMdd'T'HHmmss'Z'" + return { apiKey, host in + AppSyncRealTimeRequestAuth.ApiKey( + host: host, + apiKey: apiKey, + amzDate: formatter.string(from: Date()) + ) + } + +} diff --git a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Interceptor/SubscriptionInterceptor/AuthTokenInterceptor.swift b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Interceptor/SubscriptionInterceptor/AuthTokenInterceptor.swift new file mode 100644 index 0000000000..b0f19ffd78 --- /dev/null +++ b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Interceptor/SubscriptionInterceptor/AuthTokenInterceptor.swift @@ -0,0 +1,80 @@ +// +// Copyright Amazon.com Inc. or its affiliates. +// All Rights Reserved. +// +// SPDX-License-Identifier: Apache-2.0 +// + +import Foundation +import Amplify +@_spi(WebSocket) import AWSPluginsCore + +/// General purpose authenticatication subscriptions interceptor for providers whose only +/// requirement is to provide an authentication token via the "Authorization" header +class AuthTokenInterceptor { + + let getLatestAuthToken: () async throws -> String? + + init(getLatestAuthToken: @escaping () async throws -> String?) { + self.getLatestAuthToken = getLatestAuthToken + } + + init(authTokenProvider: AmplifyAuthTokenProvider) { + self.getLatestAuthToken = authTokenProvider.getLatestAuthToken + } + + private func getAuthToken() async -> AmplifyAuthTokenProvider.AuthToken { + // A user that is not signed in should receive an unauthorized error from + // the connection attempt. This code achieves this by always creating a valid + // request to AppSync even when the token cannot be retrieved. The request sent + // to AppSync will receive a response indicating the request is unauthorized. + // If we do not use empty token string and perform the remaining logic of the + // request construction then it will fail request validation at AppSync before + // the authorization check, which ends up being propagated back to the caller + // as a "bad request". Example of bad requests are when the header and payload + // query strings are missing or when the data is not base64 encoded. + (try? await getLatestAuthToken()) ?? "" + } +} + +extension AuthTokenInterceptor: AppSyncRequestInterceptor { + func interceptRequest(event: AppSyncRealTimeRequest, url: URL) async -> AppSyncRealTimeRequest { + guard case .start(let request) = event else { + return event + } + + let authToken = await getAuthToken() + + return .start(.init( + id: request.id, + data: request.data, + auth: .authToken(.init( + host: AppSyncRealTimeClientFactory.appSyncApiEndpoint(url).host!, + authToken: authToken + )) + )) + } +} + +extension AuthTokenInterceptor: WebSocketInterceptor { + func interceptConnection(url: URL) async -> URL { + let authToken = await getAuthToken() + + return AppSyncRealTimeRequestAuth.URLQuery( + header: .authToken(.init( + host: AppSyncRealTimeClientFactory.appSyncApiEndpoint(url).host!, + authToken: authToken + )) + ).withBaseURL(url) + } +} + +// MARK: AuthorizationTokenAuthInterceptor + DefaultLogger +extension AuthTokenInterceptor: DefaultLogger { + public static var log: Logger { + Amplify.Logging.logger(forCategory: CategoryType.api.displayName, forNamespace: String(describing: self)) + } + public var log: Logger { + Self.log + } +} diff --git a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Interceptor/SubscriptionInterceptor/AuthenticationTokenAuthInterceptor.swift b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Interceptor/SubscriptionInterceptor/AuthenticationTokenAuthInterceptor.swift deleted file mode 100644 index 02028de7db..0000000000 --- a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Interceptor/SubscriptionInterceptor/AuthenticationTokenAuthInterceptor.swift +++ /dev/null @@ -1,108 +0,0 @@ -// -// Copyright Amazon.com Inc. or its affiliates. -// All Rights Reserved. -// -// SPDX-License-Identifier: Apache-2.0 -// - -import Foundation -import AppSyncRealTimeClient -import Amplify - -/// General purpose authenticatication subscriptions interceptor for providers whose only -/// requirement is to provide an authentication token via the "Authorization" header -class AuthenticationTokenAuthInterceptor: AuthInterceptorAsync { - - let authTokenProvider: AmplifyAuthTokenProvider - - init(authTokenProvider: AmplifyAuthTokenProvider) { - self.authTokenProvider = authTokenProvider - } - - func interceptMessage(_ message: AppSyncMessage, for endpoint: URL) async -> AppSyncMessage { - let host = endpoint.host! - guard let authToken = await getAuthToken() else { - log.warn("Missing authentication token for subscription") - return message - } - - guard case .subscribe = message.messageType else { - return message - } - - let authHeader = TokenAuthHeader(token: authToken, host: host) - var payload = message.payload ?? AppSyncMessage.Payload() - payload.authHeader = authHeader - - let signedMessage = AppSyncMessage( - id: message.id, - payload: payload, - type: message.messageType - ) - return signedMessage - } - - func interceptConnection( - _ request: AppSyncConnectionRequest, - for endpoint: URL - ) async -> AppSyncConnectionRequest { - let host = endpoint.host! - guard let authToken = await getAuthToken() else { - log.warn("Missing authentication token for subscription request") - return request - } - - let authHeader = TokenAuthHeader(token: authToken, host: host) - let base64Auth = AppSyncJSONHelper.base64AuthenticationBlob(authHeader) - - let payloadData = Data(SubscriptionConstants.emptyPayload.utf8) - let payloadBase64 = payloadData.base64EncodedString() - - guard var urlComponents = URLComponents(url: request.url, resolvingAgainstBaseURL: false) else { - return request - } - let headerQuery = URLQueryItem(name: RealtimeProviderConstants.header, value: base64Auth) - let payloadQuery = URLQueryItem(name: RealtimeProviderConstants.payload, value: payloadBase64) - urlComponents.queryItems = [headerQuery, payloadQuery] - guard let url = urlComponents.url else { - return request - } - let signedRequest = AppSyncConnectionRequest(url: url) - return signedRequest - } - - private func getAuthToken() async -> AmplifyAuthTokenProvider.AuthToken? { - try? await authTokenProvider.getLatestAuthToken() - } -} - -// MARK: AuthorizationTokenAuthInterceptor + DefaultLogger -extension AuthenticationTokenAuthInterceptor: DefaultLogger { - public static var log: Logger { - Amplify.Logging.logger(forCategory: CategoryType.api.displayName, forNamespace: String(describing: self)) - } - public var log: Logger { - Self.log - } -} - -// MARK: - TokenAuthenticationHeader -/// Authentication header for user pool based auth -private class TokenAuthHeader: AuthenticationHeader { - let authorization: String - - init(token: String, host: String) { - self.authorization = token - super.init(host: host) - } - - private enum CodingKeys: String, CodingKey { - case authorization = "Authorization" - } - - override func encode(to encoder: Encoder) throws { - var container = encoder.container(keyedBy: CodingKeys.self) - try container.encode(authorization, forKey: .authorization) - try super.encode(to: encoder) - } -} diff --git a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Interceptor/SubscriptionInterceptor/IAMAuthInterceptor.swift b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Interceptor/SubscriptionInterceptor/IAMAuthInterceptor.swift index 5598190755..c3d33320c2 100644 --- a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Interceptor/SubscriptionInterceptor/IAMAuthInterceptor.swift +++ b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Interceptor/SubscriptionInterceptor/IAMAuthInterceptor.swift @@ -6,20 +6,12 @@ // import Foundation -import AWSPluginsCore +@_spi(WebSocket) import AWSPluginsCore import Amplify -import AppSyncRealTimeClient import AWSClientRuntime import ClientRuntime -class IAMAuthInterceptor: AuthInterceptorAsync { - - private static let defaultLowercasedHeaderKeys: Set = [SubscriptionConstants.authorizationkey.lowercased(), - RealtimeProviderConstants.acceptKey.lowercased(), - RealtimeProviderConstants.contentEncodingKey.lowercased(), - RealtimeProviderConstants.contentTypeKey.lowercased(), - RealtimeProviderConstants.amzDate.lowercased(), - RealtimeProviderConstants.iamSecurityTokenKey.lowercased()] +class IAMAuthInterceptor { let authProvider: CredentialsProviding let region: AWSRegionType @@ -29,50 +21,11 @@ class IAMAuthInterceptor: AuthInterceptorAsync { self.region = region } - func interceptMessage(_ message: AppSyncMessage, for endpoint: URL) async -> AppSyncMessage { - switch message.messageType { - case .subscribe: - let authHeader = await getAuthHeader(endpoint, with: message.payload?.data ?? "") - var payload = message.payload ?? AppSyncMessage.Payload() - payload.authHeader = authHeader - let signedMessage = AppSyncMessage(id: message.id, - payload: payload, - type: message.messageType) - return signedMessage - default: - Amplify.API.log.verbose("Message type does not need signing - \(message.messageType)") - } - return message - } - - func interceptConnection(_ request: AppSyncConnectionRequest, - for endpoint: URL) async -> AppSyncConnectionRequest { - let url = endpoint.appendingPathComponent(RealtimeProviderConstants.iamConnectPath) - let payloadString = SubscriptionConstants.emptyPayload - guard let authHeader = await getAuthHeader(url, with: payloadString) else { - return request - } - let base64Auth = AppSyncJSONHelper.base64AuthenticationBlob(authHeader) - - let payloadData = Data(payloadString.utf8) - let payloadBase64 = payloadData.base64EncodedString() - - guard var urlComponents = URLComponents(url: request.url, resolvingAgainstBaseURL: false) else { - return request - } - let headerQuery = Foundation.URLQueryItem(name: RealtimeProviderConstants.header, value: base64Auth) - let payloadQuery = Foundation.URLQueryItem(name: RealtimeProviderConstants.payload, value: payloadBase64) - urlComponents.queryItems = [headerQuery, payloadQuery] - guard let signedUrl = urlComponents.url else { - return request - } - let signedRequest = AppSyncConnectionRequest(url: signedUrl) - return signedRequest - } - - func getAuthHeader(_ endpoint: URL, - with payload: String, - signer: AWSSignatureV4Signer = AmplifyAWSSignatureV4Signer()) async -> IAMAuthenticationHeader? { + func getAuthHeader( + _ endpoint: URL, + with payload: String, + signer: AWSSignatureV4Signer = AmplifyAWSSignatureV4Signer() + ) async -> AppSyncRealTimeRequestAuth.IAM? { guard let host = endpoint.host else { return nil } @@ -82,14 +35,14 @@ class IAMAuthInterceptor: AuthInterceptorAsync { /// 1. A request is created with the IAM based auth headers (date, accept, content encoding, content type, and /// additional headers. let requestBuilder = SdkHttpRequestBuilder() - .withHost(endpoint.host ?? "") + .withHost(host) .withPath(endpoint.path) .withMethod(.post) .withPort(443) .withProtocol(.https) - .withHeader(name: RealtimeProviderConstants.acceptKey, value: RealtimeProviderConstants.iamAccept) - .withHeader(name: RealtimeProviderConstants.contentEncodingKey, value: RealtimeProviderConstants.iamEncoding) - .withHeader(name: URLRequestConstants.Header.contentType, value: RealtimeProviderConstants.iamConentType) + .withHeader(name: "accept", value: "application/json, text/javascript") + .withHeader(name: "content-encoding", value: "amz-1.0") + .withHeader(name: URLRequestConstants.Header.contentType, value: "application/json; charset=UTF-8") .withHeader(name: URLRequestConstants.Header.host, value: host) .withBody(.data(Data(payload.utf8))) @@ -98,44 +51,34 @@ class IAMAuthInterceptor: AuthInterceptorAsync { do { guard let urlRequest = try await signer.sigV4SignedRequest(requestBuilder: requestBuilder, credentialsProvider: authProvider, - signingName: SubscriptionConstants.appsyncServiceName, + signingName: "appsync", signingRegion: region, date: Date()) else { Amplify.Logging.error("Unable to sign request") return nil } - var authorization: String = "" // TODO: Using long lived credentials without getting a session with security token will fail // since the session token does not exist on the signed request, and is an empty string. // Once Amplify.Auth is ready to be integrated, this code path needs to be re-tested. - var securityToken: String = "" - var amzDate: String = "" - var additionalHeaders: [String: String]? - for header in urlRequest.headers.headers { - guard let value = header.value.first else { - continue - } - let headerName = header.name.lowercased() - if headerName == SubscriptionConstants.authorizationkey.lowercased() { - authorization = value - } else if headerName == RealtimeProviderConstants.amzDate.lowercased() { - amzDate = value - } else if headerName == RealtimeProviderConstants.iamSecurityTokenKey.lowercased() { - securityToken = value - } else { - additionalHeaders?.updateValue(header.value.joined(separator: ","), forKey: header.name) + let headers = urlRequest.headers.headers.reduce([String: JSONValue]()) { partialResult, header in + switch header.name.lowercased() { + case "authorization", "x-amz-date", "x-amz-security-token": + guard let headerValue = header.value.first else { + return partialResult + } + return partialResult.merging([header.name.lowercased(): .string(headerValue)]) { $1 } + default: + return partialResult } } - return IAMAuthenticationHeader(host: host, - authorization: authorization, - securityToken: securityToken, - amzDate: amzDate, - accept: RealtimeProviderConstants.iamAccept, - contentEncoding: RealtimeProviderConstants.iamEncoding, - contentType: RealtimeProviderConstants.iamConentType, - additionalHeaders: additionalHeaders) + return .init( + host: host, + authToken: headers["authorization"]?.stringValue ?? "", + securityToken: headers["x-amz-security-token"]?.stringValue ?? "", + amzDate: headers["x-amz-date"]?.stringValue ?? "" + ) } catch { Amplify.Logging.error("Unable to sign request") return nil @@ -143,79 +86,35 @@ class IAMAuthInterceptor: AuthInterceptorAsync { } } -/// Stores the headers for an IAM based authentication. This object can be serialized to a JSON object and passed as the -/// headers value for establishing subscription connections. This is used as part of the overall interceptor logic -/// which expects a subclass of `AuthenticationHeader` to be returned. -/// See `IAMAuthInterceptor.getAuthHeader` for more details. -class IAMAuthenticationHeader: AuthenticationHeader { - let authorization: String - let securityToken: String - let amzDate: String - let accept: String - let contentEncoding: String - let contentType: String - - /// Additional headers that are not one of the expected headers in the request, but because additional headers are - /// also signed (and added the authorization header), they are required to be stored here to be further encoded. - let additionalHeaders: [String: String]? - - init(host: String, - authorization: String, - securityToken: String, - amzDate: String, - accept: String, - contentEncoding: String, - contentType: String, - additionalHeaders: [String: String]?) { - self.authorization = authorization - self.securityToken = securityToken - self.amzDate = amzDate - self.accept = accept - self.contentEncoding = contentEncoding - self.contentType = contentType - self.additionalHeaders = additionalHeaders - super.init(host: host) - } - - private struct DynamicCodingKeys: CodingKey { - var stringValue: String - init?(stringValue: String) { - self.stringValue = stringValue - } - var intValue: Int? - init?(intValue: Int) { - // We are not using this, thus just return nil. If we don't return nil, then it is expected all of the - // stored properties are initialized, forcing the implementation to have logic that maintains the two - // properties `stringValue` and `intValue`. Since we don't have a string representation of an int value - // and aren't using int values for determining the coding key, then simply return nil since the encoder - // will always pass in the header key string. - self.intValue = intValue - self.stringValue = "" - +extension IAMAuthInterceptor: WebSocketInterceptor { + func interceptConnection(url: URL) async -> URL { + let connectUrl = AppSyncRealTimeClientFactory.appSyncApiEndpoint(url).appendingPathComponent("connect") + guard let authHeader = await getAuthHeader(connectUrl, with: "{}") else { + return connectUrl } + + return AppSyncRealTimeRequestAuth.URLQuery( + header: .iam(authHeader) + ).withBaseURL(url) } +} - override func encode(to encoder: Encoder) throws { - var container = encoder.container(keyedBy: DynamicCodingKeys.self) - // Force unwrapping when creating a `DynamicCodingKeys` will always be successful since the string constructor - // will never return nil even though the constructor is optional (conformance to CodingKey). - try container.encode(authorization, - forKey: DynamicCodingKeys(stringValue: SubscriptionConstants.authorizationkey)!) - try container.encode(securityToken, - forKey: DynamicCodingKeys(stringValue: RealtimeProviderConstants.iamSecurityTokenKey)!) - try container.encode(amzDate, - forKey: DynamicCodingKeys(stringValue: RealtimeProviderConstants.amzDate)!) - try container.encode(accept, - forKey: DynamicCodingKeys(stringValue: RealtimeProviderConstants.acceptKey)!) - try container.encode(contentEncoding, - forKey: DynamicCodingKeys(stringValue: RealtimeProviderConstants.contentEncodingKey)!) - try container.encode(contentType, - forKey: DynamicCodingKeys(stringValue: RealtimeProviderConstants.contentTypeKey)!) - if let headers = additionalHeaders { - for (key, value) in headers { - try container.encode(value, forKey: DynamicCodingKeys(stringValue: key)!) - } +extension IAMAuthInterceptor: AppSyncRequestInterceptor { + func interceptRequest( + event: AppSyncRealTimeRequest, + url: URL + ) async -> AppSyncRealTimeRequest { + guard case .start(let request) = event else { + return event } - try super.encode(to: encoder) + + let authHeader = await getAuthHeader( + AppSyncRealTimeClientFactory.appSyncApiEndpoint(url), + with: request.data) + return .start(.init( + id: request.id, + data: request.data, + auth: authHeader.map { .iam($0) } + )) } } diff --git a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Interceptor/SubscriptionInterceptor/OIDCAuthProviderWrapper.swift b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Interceptor/SubscriptionInterceptor/OIDCAuthProviderWrapper.swift deleted file mode 100644 index 4093526d01..0000000000 --- a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Interceptor/SubscriptionInterceptor/OIDCAuthProviderWrapper.swift +++ /dev/null @@ -1,23 +0,0 @@ -// -// Copyright Amazon.com Inc. or its affiliates. -// All Rights Reserved. -// -// SPDX-License-Identifier: Apache-2.0 -// - -import Amplify -import Foundation -import AppSyncRealTimeClient - -class OIDCAuthProviderWrapper: OIDCAuthProviderAsync { - - let authTokenProvider: AmplifyAuthTokenProvider - - public init(authTokenProvider: AmplifyAuthTokenProvider) { - self.authTokenProvider = authTokenProvider - } - - func getLatestAuthToken() async throws -> String { - try await authTokenProvider.getLatestAuthToken() - } -} diff --git a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Operation/AWSGraphQLSubscriptionTaskRunner.swift b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Operation/AWSGraphQLSubscriptionTaskRunner.swift index bd457fcc3c..3e70654298 100644 --- a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Operation/AWSGraphQLSubscriptionTaskRunner.swift +++ b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Operation/AWSGraphQLSubscriptionTaskRunner.swift @@ -8,7 +8,7 @@ import Amplify import Foundation import AWSPluginsCore -import AppSyncRealTimeClient +import Combine public class AWSGraphQLSubscriptionTaskRunner: InternalTaskRunner, InternalTaskAsyncThrowingSequence, InternalTaskThrowingChannel { public typealias Request = GraphQLOperationRequest @@ -17,37 +17,49 @@ public class AWSGraphQLSubscriptionTaskRunner: InternalTaskRunner, public var request: GraphQLOperationRequest public var context = InternalTaskAsyncThrowingSequenceContext>() + var appSyncClient: AppSyncRealTimeClientProtocol? + var subscription: AnyCancellable? { + willSet { + self.subscription?.cancel() + } + } + let appSyncClientFactory: AppSyncRealTimeClientFactoryProtocol let pluginConfig: AWSAPICategoryPluginConfiguration - let subscriptionConnectionFactory: SubscriptionConnectionFactory let authService: AWSAuthServiceBehavior var apiAuthProviderFactory: APIAuthProviderFactory private let userAgent = AmplifyAWSServiceConfiguration.userAgentLib + private let subscriptionId = UUID().uuidString - var subscriptionConnection: SubscriptionConnection? - var subscriptionItem: SubscriptionItem? private var running = false - private let subscriptionQueue = DispatchQueue(label: "AWSGraphQLSubscriptionOperation.subscriptionQueue") - init(request: Request, pluginConfig: AWSAPICategoryPluginConfiguration, - subscriptionConnectionFactory: SubscriptionConnectionFactory, + appSyncClientFactory: AppSyncRealTimeClientFactoryProtocol, authService: AWSAuthServiceBehavior, apiAuthProviderFactory: APIAuthProviderFactory) { self.request = request self.pluginConfig = pluginConfig - self.subscriptionConnectionFactory = subscriptionConnectionFactory + self.appSyncClientFactory = appSyncClientFactory self.authService = authService self.apiAuthProviderFactory = apiAuthProviderFactory } public func cancel() { - subscriptionQueue.sync { - if let subscriptionItem = subscriptionItem, let subscriptionConnection = subscriptionConnection { - subscriptionConnection.unsubscribe(item: subscriptionItem) - let subscriptionEvent = GraphQLSubscriptionEvent.connection(.disconnected) - send(subscriptionEvent) + self.send(GraphQLSubscriptionEvent.connection(.disconnected)) + Task { [weak self] in + guard let self else { + return + } + guard let appSyncClient = self.appSyncClient else { + return } + do { + try await appSyncClient.unsubscribe(id: self.subscriptionId) + } catch { + print("[AWSGraphQLSubscriptionTaskRunner] Failed to unsubscribe \(self.subscriptionId)") + } + + await appSyncClient.disconnectWhenIdel() } } @@ -79,33 +91,28 @@ public class AWSGraphQLSubscriptionTaskRunner: InternalTaskRunner, return } - // Retrieve request plugin option and - // auth type in case of a multi-auth setup let pluginOptions = request.options.pluginOptions as? AWSAPIPluginDataStoreOptions - let urlRequest = generateSubscriptionURLRequest(from: endpointConfig) - // Retrieve the subscription connection - subscriptionQueue.sync { - do { - subscriptionConnection = try subscriptionConnectionFactory - .getOrCreateConnection(for: endpointConfig, - urlRequest: urlRequest, - authService: authService, - authType: pluginOptions?.authType, - apiAuthProviderFactory: apiAuthProviderFactory) - } catch { - let error = APIError.operationError("Unable to get connection for api \(endpointConfig.name)", "", error) - fail(error) - return - } + do { + self.appSyncClient = try await appSyncClientFactory.getAppSyncRealTimeClient( + for: endpointConfig, + endpoint: endpointConfig.baseURL, + authService: authService, + authType: pluginOptions?.authType, + apiAuthProviderFactory: apiAuthProviderFactory + ) // Create subscription - - subscriptionItem = subscriptionConnection?.subscribe(requestString: request.document, - variables: request.variables, - eventHandler: { [weak self] event, _ in + self.subscription = try await appSyncClient?.subscribe( + id: subscriptionId, + query: encodeRequest(query: request.document, variables: request.variables) + ).sink(receiveValue: { [weak self] event in self?.onAsyncSubscriptionEvent(event: event) }) + } catch { + let error = APIError.operationError("Unable to get connection for api \(endpointConfig.name)", "", error) + fail(error) + return } } @@ -119,29 +126,22 @@ public class AWSGraphQLSubscriptionTaskRunner: InternalTaskRunner, // MARK: - Subscription callbacks - private func onAsyncSubscriptionEvent(event: SubscriptionItemEvent) { + private func onAsyncSubscriptionEvent(event: AppSyncSubscriptionEvent) { switch event { - case .connection(let subscriptionConnectionEvent): - onSubscriptionEvent(subscriptionConnectionEvent) - case .data(let data): + case .data(let json): + guard let data = try? JSONEncoder().encode(json) else { + return + } onGraphQLResponseData(data) - case .failed(let error): - onSubscriptionFailure(error) - } - } - - private func onSubscriptionEvent(_ subscriptionConnectionEvent: SubscriptionConnectionEvent) { - switch subscriptionConnectionEvent { - case .connecting: - let subscriptionEvent = GraphQLSubscriptionEvent.connection(.connecting) - send(subscriptionEvent) - case .connected: - let subscriptionEvent = GraphQLSubscriptionEvent.connection(.connected) - send(subscriptionEvent) - case .disconnected: - let subscriptionEvent = GraphQLSubscriptionEvent.connection(.disconnected) - send(subscriptionEvent) + case .subscribing: + send(GraphQLSubscriptionEvent.connection(.connecting)) + case .subscribed: + send(GraphQLSubscriptionEvent.connection(.connected)) + case .unsubscribed: + send(GraphQLSubscriptionEvent.connection(.disconnected)) finish() + case .error(let errors): + fail(toAPIError(errors, type: R.self)) } } @@ -171,56 +171,36 @@ public class AWSGraphQLSubscriptionTaskRunner: InternalTaskRunner, } } - private func onSubscriptionFailure(_ error: Error) { - var errorDescription = "Subscription item event failed with error" - if case let ConnectionProviderError.subscription(_, payload) = error, - let errors = payload?["errors"] as? AppSyncJSONValue, - let graphQLErrors = try? GraphQLErrorDecoder.decodeAppSyncErrors(errors) { - - if graphQLErrors.hasUnauthorizedError() { - errorDescription += ": \(APIError.UnauthorizedMessageString)" - } - - let graphQLResponseError = GraphQLResponseError.error(graphQLErrors) - fail(APIError.operationError(errorDescription, "", graphQLResponseError)) - return - } else if case ConnectionProviderError.unauthorized = error { - errorDescription += ": \(APIError.UnauthorizedMessageString)" - } else if case ConnectionProviderError.connection = error { - errorDescription += ": connection" - let error = URLError(.networkConnectionLost) - fail(APIError.networkError(errorDescription, nil, error)) - return - } - - fail(APIError.operationError(errorDescription, "", error)) - } } // Class is still necessary. See https://github.com/aws-amplify/amplify-swift/issues/2252 final public class AWSGraphQLSubscriptionOperation: GraphQLSubscriptionOperation { let pluginConfig: AWSAPICategoryPluginConfiguration - let subscriptionConnectionFactory: SubscriptionConnectionFactory + let appSyncRealTimeClientFactory: AppSyncRealTimeClientFactoryProtocol let authService: AWSAuthServiceBehavior private let userAgent = AmplifyAWSServiceConfiguration.userAgentLib - var subscriptionConnection: SubscriptionConnection? - var subscriptionItem: SubscriptionItem? - var apiAuthProviderFactory: APIAuthProviderFactory + var appSyncRealTimeClient: AppSyncRealTimeClientProtocol? + var subscription: AnyCancellable? { + willSet { + self.subscription?.cancel() + } + } - private let subscriptionQueue = DispatchQueue(label: "AWSGraphQLSubscriptionOperation.subscriptionQueue") + var apiAuthProviderFactory: APIAuthProviderFactory + private let subscriptionId = UUID().uuidString init(request: GraphQLOperationRequest, pluginConfig: AWSAPICategoryPluginConfiguration, - subscriptionConnectionFactory: SubscriptionConnectionFactory, + appSyncRealTimeClientFactory: AppSyncRealTimeClientFactoryProtocol, authService: AWSAuthServiceBehavior, apiAuthProviderFactory: APIAuthProviderFactory, inProcessListener: AWSGraphQLSubscriptionOperation.InProcessListener?, resultListener: AWSGraphQLSubscriptionOperation.ResultListener?) { self.pluginConfig = pluginConfig - self.subscriptionConnectionFactory = subscriptionConnectionFactory + self.appSyncRealTimeClientFactory = appSyncRealTimeClientFactory self.authService = authService self.apiAuthProviderFactory = apiAuthProviderFactory @@ -232,17 +212,26 @@ final public class AWSGraphQLSubscriptionOperation: GraphQLSubscri } override public func cancel() { - subscriptionQueue.sync { - if let subscriptionItem = subscriptionItem, let subscriptionConnection = subscriptionConnection { - subscriptionConnection.unsubscribe(item: subscriptionItem) - let subscriptionEvent = GraphQLSubscriptionEvent.connection(.disconnected) - dispatchInProcess(data: subscriptionEvent) + super.cancel() + + Task { [weak self] in + guard let self else { + return } - } - dispatch(result: .successfulVoid) - super.cancel() - finish() + guard let appSyncRealTimeClient = self.appSyncRealTimeClient else { + return + } + + do { + try await appSyncRealTimeClient.unsubscribe(id: subscriptionId) + finish() + } catch { + print("[AWSGraphQLSubscriptionOperation] Failed to unsubscribe \(subscriptionId), error: \(error)") + } + + await appSyncRealTimeClient.disconnectWhenIdel() + } } override public func main() { @@ -278,20 +267,24 @@ final public class AWSGraphQLSubscriptionOperation: GraphQLSubscri return } - // Retrieve request plugin option and - // auth type in case of a multi-auth setup let pluginOptions = request.options.pluginOptions as? AWSAPIPluginDataStoreOptions - let urlRequest = generateSubscriptionURLRequest(from: endpointConfig) - - // Retrieve the subscription connection - subscriptionQueue.sync { + Task { do { - subscriptionConnection = try subscriptionConnectionFactory - .getOrCreateConnection(for: endpointConfig, - urlRequest: urlRequest, - authService: authService, - authType: pluginOptions?.authType, - apiAuthProviderFactory: apiAuthProviderFactory) + appSyncRealTimeClient = try await appSyncRealTimeClientFactory.getAppSyncRealTimeClient( + for: endpointConfig, + endpoint: endpointConfig.baseURL, + authService: authService, + authType: pluginOptions?.authType, + apiAuthProviderFactory: apiAuthProviderFactory + ) + + // Create subscription + self.subscription = try await appSyncRealTimeClient?.subscribe( + id: subscriptionId, + query: encodeRequest(query: request.document, variables: request.variables) + ).sink(receiveValue: { [weak self] event in + self?.onAsyncSubscriptionEvent(event: event) + }) } catch { let error = APIError.operationError("Unable to get connection for api \(endpointConfig.name)", "", error) dispatch(result: .failure(error)) @@ -299,13 +292,6 @@ final public class AWSGraphQLSubscriptionOperation: GraphQLSubscri return } - // Create subscription - - subscriptionItem = subscriptionConnection?.subscribe(requestString: request.document, - variables: request.variables, - eventHandler: { [weak self] event, _ in - self?.onAsyncSubscriptionEvent(event: event) - }) } } @@ -319,30 +305,24 @@ final public class AWSGraphQLSubscriptionOperation: GraphQLSubscri // MARK: - Subscription callbacks - private func onAsyncSubscriptionEvent(event: SubscriptionItemEvent) { + private func onAsyncSubscriptionEvent(event: AppSyncSubscriptionEvent) { switch event { - case .connection(let subscriptionConnectionEvent): - onSubscriptionEvent(subscriptionConnectionEvent) - case .data(let data): + case .data(let json): + guard let data = try? JSONEncoder().encode(json) else { + return + } onGraphQLResponseData(data) - case .failed(let error): - onSubscriptionFailure(error) - } - } - - private func onSubscriptionEvent(_ subscriptionConnectionEvent: SubscriptionConnectionEvent) { - switch subscriptionConnectionEvent { - case .connecting: - let subscriptionEvent = GraphQLSubscriptionEvent.connection(.connecting) - dispatchInProcess(data: subscriptionEvent) - case .connected: - let subscriptionEvent = GraphQLSubscriptionEvent.connection(.connected) - dispatchInProcess(data: subscriptionEvent) - case .disconnected: - let subscriptionEvent = GraphQLSubscriptionEvent.connection(.disconnected) - dispatchInProcess(data: subscriptionEvent) + case .subscribing: + dispatchInProcess(data: GraphQLSubscriptionEvent.connection(.connecting)) + case .subscribed: + dispatchInProcess(data: GraphQLSubscriptionEvent.connection(.connected)) + case .unsubscribed: + dispatchInProcess(data: GraphQLSubscriptionEvent.connection(.disconnected)) dispatch(result: .successfulVoid) finish() + case .error(let errors): + dispatch(result: .failure(toAPIError(errors, type: R.self))) + finish() } } @@ -374,43 +354,53 @@ final public class AWSGraphQLSubscriptionOperation: GraphQLSubscri finish() } } +} - private func onSubscriptionFailure(_ error: Error) { - var errorDescription = "Subscription item event failed with error" - if case let ConnectionProviderError.subscription(_, payload) = error, - let errors = payload?["errors"] as? AppSyncJSONValue, - let graphQLErrors = try? GraphQLErrorDecoder.decodeAppSyncErrors(errors) { +fileprivate func encodeRequest(query: String, variables: [String: Any]?) -> String { + var json: [String: Any] = [ + "query": query + ] - if graphQLErrors.hasUnauthorizedError() { - errorDescription += ": \(APIError.UnauthorizedMessageString)" - } + if let variables { + json["variables"] = variables + } - let graphQLResponseError = GraphQLResponseError.error(graphQLErrors) - dispatch(result: .failure(APIError.operationError(errorDescription, "", graphQLResponseError))) - finish() - return - } else if case ConnectionProviderError.unauthorized = error { - errorDescription += ": \(APIError.UnauthorizedMessageString)" - } else if case ConnectionProviderError.connection = error { - errorDescription += ": connection" - let error = URLError(.networkConnectionLost) - dispatch(result: .failure(APIError.networkError(errorDescription, nil, error))) - finish() - return - } - dispatch(result: .failure(APIError.operationError(errorDescription, "", error))) - finish() + do { + return String(data: try JSONSerialization.data(withJSONObject: json), encoding: .utf8)! + } catch { + return "" } } -extension Array where Element == GraphQLError { - func hasUnauthorizedError() -> Bool { - contains { graphQLError in - if case let .string(errorTypeValue) = graphQLError.extensions?["errorType"], - case .unauthorized = AppSyncErrorType(errorTypeValue) { - return true - } - return false - } +fileprivate func toAPIError(_ errors: [Error], type: R.Type) -> APIError { + func errorDescription(_ hasAuthorizationError: Bool = false) -> String { + "Subscription item event failed with error" + + (hasAuthorizationError ? ": \(APIError.UnauthorizedMessageString)" : "") + } + + switch errors { + case let errors as [AppSyncRealTimeRequest.Error]: + let hasAuthorizationError = errors.contains(where: { $0 == .unauthorized}) + return APIError.operationError( + errorDescription(hasAuthorizationError), + "", + errors.first + ) + case let errors as [GraphQLError]: + let hasAuthorizationError = errors.map(\.extensions) + .compactMap { $0.flatMap { $0["errorType"]?.stringValue } } + .contains(where: { AppSyncErrorType($0) == .unauthorized }) + return APIError.operationError( + errorDescription(hasAuthorizationError), + "", + GraphQLResponseError.error(errors) + ) + default: + return APIError.operationError( + errorDescription(), + "", + errors.first + ) } + } diff --git a/AmplifyPlugins/API/Sources/AWSAPIPlugin/SubscriptionFactory/AWSOIDCAuthProvider.swift b/AmplifyPlugins/API/Sources/AWSAPIPlugin/SubscriptionFactory/AWSOIDCAuthProvider.swift index b963ee3046..1887aa1b06 100644 --- a/AmplifyPlugins/API/Sources/AWSAPIPlugin/SubscriptionFactory/AWSOIDCAuthProvider.swift +++ b/AmplifyPlugins/API/Sources/AWSAPIPlugin/SubscriptionFactory/AWSOIDCAuthProvider.swift @@ -7,9 +7,8 @@ import Foundation import AWSPluginsCore -import AppSyncRealTimeClient -class AWSOIDCAuthProvider: OIDCAuthProviderAsync { +class AWSOIDCAuthProvider { var authService: AWSAuthServiceBehavior diff --git a/AmplifyPlugins/API/Sources/AWSAPIPlugin/SubscriptionFactory/AWSSubscriptionConnectionFactory.swift b/AmplifyPlugins/API/Sources/AWSAPIPlugin/SubscriptionFactory/AWSSubscriptionConnectionFactory.swift deleted file mode 100644 index 38404de216..0000000000 --- a/AmplifyPlugins/API/Sources/AWSAPIPlugin/SubscriptionFactory/AWSSubscriptionConnectionFactory.swift +++ /dev/null @@ -1,102 +0,0 @@ -// -// Copyright Amazon.com Inc. or its affiliates. -// All Rights Reserved. -// -// SPDX-License-Identifier: Apache-2.0 -// - -import Foundation -import AWSPluginsCore -import Amplify -import AppSyncRealTimeClient - -class AWSSubscriptionConnectionFactory: SubscriptionConnectionFactory { - /// Key used to map an API to a ConnectionProvider - private struct MapperCacheKey: Hashable { - let apiName: String - let authType: AWSAuthorizationType? - } - - private let concurrencyQueue = DispatchQueue(label: "com.amazonaws.amplify.AWSSubscriptionConnectionFactory", - target: DispatchQueue.global()) - - private var apiToConnectionProvider: [MapperCacheKey: ConnectionProvider] = [:] - - func getOrCreateConnection( - for endpointConfig: AWSAPICategoryPluginConfiguration.EndpointConfig, - urlRequest: URLRequest, - authService: AWSAuthServiceBehavior, - authType: AWSAuthorizationType? = nil, - apiAuthProviderFactory: APIAuthProviderFactory - ) throws -> SubscriptionConnection { - return try concurrencyQueue.sync { - let apiName = endpointConfig.name - - let authInterceptor = try self.getInterceptor( - for: self.getOrCreateAuthConfiguration(from: endpointConfig, authType: authType), - authService: authService, - apiAuthProviderFactory: apiAuthProviderFactory - ) - - // create or retrieve the connection provider. If creating, add interceptors onto the provider. - let connectionProvider = apiToConnectionProvider[MapperCacheKey(apiName: apiName, authType: authType)] ?? - ConnectionProviderFactory.createConnectionProviderAsync(for: urlRequest, - authInterceptor: authInterceptor, - connectionType: .appSyncRealtime) - - // store the connection provider for this api - apiToConnectionProvider[MapperCacheKey(apiName: apiName, authType: authType)] = connectionProvider - - // create a subscription connection for subscribing and unsubscribing on the connection provider - return AppSyncSubscriptionConnection(provider: connectionProvider) - } - } - - // MARK: Private methods - - private func getOrCreateAuthConfiguration(from endpointConfig: AWSAPICategoryPluginConfiguration.EndpointConfig, - authType: AWSAuthorizationType?) throws -> AWSAuthorizationConfiguration { - // create a configuration if there's an override auth type - if let authType = authType { - return try endpointConfig.authorizationConfigurationFor(authType: authType) - } - - return endpointConfig.authorizationConfiguration - } - - private func getInterceptor(for authorizationConfiguration: AWSAuthorizationConfiguration, - authService: AWSAuthServiceBehavior, - apiAuthProviderFactory: APIAuthProviderFactory) throws -> AuthInterceptorAsync { - let authInterceptor: AuthInterceptorAsync - - switch authorizationConfiguration { - case .apiKey(let apiKeyConfiguration): - authInterceptor = APIKeyAuthInterceptor(apiKeyConfiguration.apiKey) - case .amazonCognitoUserPools: - let provider = AWSOIDCAuthProvider(authService: authService) - authInterceptor = OIDCAuthInterceptorAsync(provider) - case .awsIAM(let awsIAMConfiguration): - authInterceptor = IAMAuthInterceptor(authService.getCredentialsProvider(), - region: awsIAMConfiguration.region) - case .openIDConnect: - guard let oidcAuthProvider = apiAuthProviderFactory.oidcAuthProvider() else { - throw APIError.invalidConfiguration( - "Using openIDConnect requires passing in an APIAuthProvider with an OIDC AuthProvider", - "When instantiating AWSAPIPlugin pass in an instance of APIAuthProvider", nil) - } - let wrappedProvider = OIDCAuthProviderWrapper(authTokenProvider: oidcAuthProvider) - authInterceptor = OIDCAuthInterceptorAsync(wrappedProvider) - case .function: - guard let functionAuthProvider = apiAuthProviderFactory.functionAuthProvider() else { - throw APIError.invalidConfiguration( - "Using function as auth provider requires passing in an APIAuthProvider with a Function AuthProvider", - "When instantiating AWSAPIPlugin pass in an instance of APIAuthProvider", nil) - } - authInterceptor = AuthenticationTokenAuthInterceptor(authTokenProvider: functionAuthProvider) - case .none: - throw APIError.unknown("Cannot create AppSync subscription for none auth mode", "") - } - - return authInterceptor - } -} diff --git a/AmplifyPlugins/API/Sources/AWSAPIPlugin/SubscriptionFactory/AppSyncRealTimeClientFactory.swift b/AmplifyPlugins/API/Sources/AWSAPIPlugin/SubscriptionFactory/AppSyncRealTimeClientFactory.swift new file mode 100644 index 0000000000..b97459c0e1 --- /dev/null +++ b/AmplifyPlugins/API/Sources/AWSAPIPlugin/SubscriptionFactory/AppSyncRealTimeClientFactory.swift @@ -0,0 +1,191 @@ +// +// Copyright Amazon.com Inc. or its affiliates. +// All Rights Reserved. +// +// SPDX-License-Identifier: Apache-2.0 +// + + +import Foundation +import Amplify +import Combine +@_spi(WebSocket) import AWSPluginsCore + +protocol AppSyncRealTimeClientFactoryProtocol { + func getAppSyncRealTimeClient( + for endpointConfig: AWSAPICategoryPluginConfiguration.EndpointConfig, + endpoint: URL, + authService: AWSAuthServiceBehavior, + authType: AWSAuthorizationType?, + apiAuthProviderFactory: APIAuthProviderFactory + ) async throws -> AppSyncRealTimeClientProtocol +} + +protocol AppSyncRealTimeClientProtocol { + func connect() async throws + func disconnectWhenIdel() async + func disconnect() async + func subscribe(id: String, query: String) async throws -> AnyPublisher + func unsubscribe(id: String) async throws +} + +actor AppSyncRealTimeClientFactory: AppSyncRealTimeClientFactoryProtocol { + struct MapperCacheKey: Hashable { + let apiName: String + let authType: AWSAuthorizationType? + } + + public private(set) var apiToClientCache = [MapperCacheKey: AppSyncRealTimeClientProtocol]() + + public func getAppSyncRealTimeClient( + for endpointConfig: AWSAPICategoryPluginConfiguration.EndpointConfig, + endpoint: URL, + authService: AWSAuthServiceBehavior, + authType: AWSAuthorizationType? = nil, + apiAuthProviderFactory: APIAuthProviderFactory + ) throws -> AppSyncRealTimeClientProtocol { + let apiName = endpointConfig.name + + let authInterceptor = try self.getInterceptor( + for: self.getOrCreateAuthConfiguration(from: endpointConfig, authType: authType), + authService: authService, + apiAuthProviderFactory: apiAuthProviderFactory + ) + + // create or retrieve the connection provider. If creating, add interceptors onto the provider. + if let appSyncClient = apiToClientCache[MapperCacheKey(apiName: apiName, authType: authType)] { + return appSyncClient + } else { + let appSyncClient = AppSyncRealTimeClient( + endpoint: endpoint, + requestInterceptor: authInterceptor, + webSocketClient: WebSocketClient( + url: Self.appSyncRealTimeEndpoint(endpoint), + protocols: ["graphql-ws"], + interceptor: authInterceptor + ) + ) + + // store the connection provider for this api + apiToClientCache[MapperCacheKey(apiName: apiName, authType: authType)] = appSyncClient + // create a subscription connection for subscribing and unsubscribing on the connection provider + return appSyncClient + } + } + + private func getOrCreateAuthConfiguration( + from endpointConfig: AWSAPICategoryPluginConfiguration.EndpointConfig, + authType: AWSAuthorizationType? + ) throws -> AWSAuthorizationConfiguration { + // create a configuration if there's an override auth type + if let authType = authType { + return try endpointConfig.authorizationConfigurationFor(authType: authType) + } + + return endpointConfig.authorizationConfiguration + } + + private func getInterceptor( + for authorizationConfiguration: AWSAuthorizationConfiguration, + authService: AWSAuthServiceBehavior, + apiAuthProviderFactory: APIAuthProviderFactory + ) throws -> AppSyncRequestInterceptor & WebSocketInterceptor { + switch authorizationConfiguration { + case .apiKey(let apiKeyConfiguration): + return APIKeyAuthInterceptor(apiKey: apiKeyConfiguration.apiKey) + case .amazonCognitoUserPools: + let provider = AWSOIDCAuthProvider(authService: authService) + return AuthTokenInterceptor(getLatestAuthToken: provider.getLatestAuthToken) + case .awsIAM(let awsIAMConfiguration): + return IAMAuthInterceptor(authService.getCredentialsProvider(), + region: awsIAMConfiguration.region) + case .openIDConnect: + guard let oidcAuthProvider = apiAuthProviderFactory.oidcAuthProvider() else { + throw APIError.invalidConfiguration( + "Using openIDConnect requires passing in an APIAuthProvider with an OIDC AuthProvider", + "When instantiating AWSAPIPlugin pass in an instance of APIAuthProvider", nil) + } + return AuthTokenInterceptor(getLatestAuthToken: oidcAuthProvider.getLatestAuthToken) + case .function: + guard let functionAuthProvider = apiAuthProviderFactory.functionAuthProvider() else { + throw APIError.invalidConfiguration( + "Using function as auth provider requires passing in an APIAuthProvider with a Function AuthProvider", + "When instantiating AWSAPIPlugin pass in an instance of APIAuthProvider", nil) + } + return AuthTokenInterceptor(authTokenProvider: functionAuthProvider) + case .none: + throw APIError.unknown("Cannot create AppSync subscription for none auth mode", "") + } + } +} + + +extension AppSyncRealTimeClientFactory { + + /** + Converting appsync api url to realtime api url + 1. api.example.com/graphql -> api.example.com/graphql/realtime + 2. abc.appsync-api.us-east-1.amazonaws.com/graphql -> abc.appsync-realtime-api.us-east-1.amazonaws.com/graphql + */ + static func appSyncRealTimeEndpoint(_ url: URL) -> URL { + guard let host = url.host else { + return url + } + + guard host.hasSuffix("amazonaws.com") else { + return url.appendingPathComponent("realtime") + } + + guard var urlComponents = URLComponents(url: url, resolvingAgainstBaseURL: false) else { + return url + } + + urlComponents.host = host.replacingOccurrences(of: "appsync-api", with: "appsync-realtime-api") + guard let realTimeUrl = urlComponents.url else { + return url + } + + return realTimeUrl + } + + /** + Converting appsync realtime api url to api url + 1. api.example.com/graphql/realtime -> api.example.com/graphql + 2. abc.appsync-realtime-api.us-east-1.amazonaws.com/graphql -> abc.appsync-api.us-east-1.amazonaws.com/graphql + */ + static func appSyncApiEndpoint(_ url: URL) -> URL { + guard let host = url.host else { + return url + } + + guard host.hasSuffix("amazonaws.com") else { + if url.lastPathComponent == "realtime" { + return url.deletingLastPathComponent() + } + return url + } + + guard var urlComponents = URLComponents(url: url, resolvingAgainstBaseURL: false) else { + return url + } + + urlComponents.host = host.replacingOccurrences(of: "appsync-realtime-api", with: "appsync-api") + guard let apiUrl = urlComponents.url else { + return url + } + return apiUrl + } +} + +extension AppSyncRealTimeClientFactory: Resettable { + func reset() async { + await withTaskGroup(of: Void.self) { taskGroup in + self.apiToClientCache.values + .compactMap { $0 as? Resettable } + .forEach { resettable in + taskGroup.addTask { await resettable.reset()} + } + await taskGroup.waitForAll() + } + } +} diff --git a/AmplifyPlugins/API/Sources/AWSAPIPlugin/SubscriptionFactory/SubscriptionConnectionFactory.swift b/AmplifyPlugins/API/Sources/AWSAPIPlugin/SubscriptionFactory/SubscriptionConnectionFactory.swift deleted file mode 100644 index 98c5f6aa4d..0000000000 --- a/AmplifyPlugins/API/Sources/AWSAPIPlugin/SubscriptionFactory/SubscriptionConnectionFactory.swift +++ /dev/null @@ -1,23 +0,0 @@ -// -// Copyright Amazon.com Inc. or its affiliates. -// All Rights Reserved. -// -// SPDX-License-Identifier: Apache-2.0 -// - -import Foundation - -import Amplify -import AWSPluginsCore -import AppSyncRealTimeClient - -/// Protocol for the subscription factory -protocol SubscriptionConnectionFactory { - - /// Get connection based on the connection type - func getOrCreateConnection(for endpointConfig: AWSAPICategoryPluginConfiguration.EndpointConfig, - urlRequest: URLRequest, - authService: AWSAuthServiceBehavior, - authType: AWSAuthorizationType?, - apiAuthProviderFactory: APIAuthProviderFactory) throws -> SubscriptionConnection -} diff --git a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Support/Decode/GraphQLErrorDecoder.swift b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Support/Decode/GraphQLErrorDecoder.swift index 7e7b1dcb31..f6c80c047c 100644 --- a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Support/Decode/GraphQLErrorDecoder.swift +++ b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Support/Decode/GraphQLErrorDecoder.swift @@ -6,7 +6,6 @@ // import Amplify -import AppSyncRealTimeClient import Foundation struct GraphQLErrorDecoder { @@ -29,12 +28,13 @@ struct GraphQLErrorDecoder { return responseErrors } - static func decodeAppSyncErrors(_ appSyncJSON: AppSyncJSONValue) throws -> [GraphQLError] { - guard case let .array(errors) = appSyncJSON else { - throw APIError.unknown("Expected 'errors' field not found in \(String(describing: appSyncJSON))", "", nil) + static func decodeAppSyncErrors(_ payload: JSONValue?) throws -> [GraphQLError] { + guard let errorsJson = payload?.errors else { + throw APIError.unknown("Expected 'errors' field not found in \(String(describing: payload))", "", nil) } - let convertedValues = errors.map(AppSyncJSONValue.toJSONValue) - return try decodeErrors(graphQLErrors: convertedValues) + + let errors = errorsJson.asArray ?? [errorsJson] + return try decodeErrors(graphQLErrors: errors) } static func decode(graphQLErrorJSON: JSONValue) throws -> GraphQLError { diff --git a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Support/Utils/AppSyncJSONValue+toJSONValue.swift b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Support/Utils/AppSyncJSONValue+toJSONValue.swift deleted file mode 100644 index 6fba0559e8..0000000000 --- a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Support/Utils/AppSyncJSONValue+toJSONValue.swift +++ /dev/null @@ -1,32 +0,0 @@ -// -// Copyright Amazon.com Inc. or its affiliates. -// All Rights Reserved. -// -// SPDX-License-Identifier: Apache-2.0 -// - -import Foundation -import Amplify -import AppSyncRealTimeClient - -extension AppSyncJSONValue { - static func toJSONValue(_ json: AppSyncJSONValue) -> JSONValue { - switch json { - case .array(let values): - return JSONValue.array(values.map(AppSyncJSONValue.toJSONValue)) - case .boolean(let value): - return JSONValue.boolean(value) - case .null: - return JSONValue.null - case .number(let value): - return JSONValue.number(value) - case .object(let content): - return JSONValue.object(content.reduce(into: [:]) { acc, partial in - let (key, value) = partial - acc[key] = AppSyncJSONValue.toJSONValue(value) - }) - case .string(let value): - return JSONValue.string(value) - } - } -} diff --git a/AmplifyPlugins/API/Tests/APIHostApp/APIHostApp.xcodeproj/project.pbxproj b/AmplifyPlugins/API/Tests/APIHostApp/APIHostApp.xcodeproj/project.pbxproj index e44cd6e59a..495154738f 100644 --- a/AmplifyPlugins/API/Tests/APIHostApp/APIHostApp.xcodeproj/project.pbxproj +++ b/AmplifyPlugins/API/Tests/APIHostApp/APIHostApp.xcodeproj/project.pbxproj @@ -227,6 +227,7 @@ 39E0F2AA28A440A700939D9F /* GraphQLWithUserPoolIntegrationTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 21698BBD28899B6D004BD994 /* GraphQLWithUserPoolIntegrationTests.swift */; }; 39E0F2AD28A441B100939D9F /* TestConfigHelper.swift in Sources */ = {isa = PBXBuildFile; fileRef = 39E0F2AC28A441B100939D9F /* TestConfigHelper.swift */; }; 39E0F2AF28A4425C00939D9F /* Todo.swift in Sources */ = {isa = PBXBuildFile; fileRef = 39E0F2AE28A4425C00939D9F /* Todo.swift */; }; + 606C8B792B895E5A00716094 /* AppSyncRealTimeClientTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 606C8B782B895E5A00716094 /* AppSyncRealTimeClientTests.swift */; }; 681B35422A43962D0074F369 /* Team2+Schema.swift in Sources */ = {isa = PBXBuildFile; fileRef = 2126271F289ABFE9003788E3 /* Team2+Schema.swift */; }; 681B35432A43962D0074F369 /* EnumTestModel.swift in Sources */ = {isa = PBXBuildFile; fileRef = 21262707289ABFE6003788E3 /* EnumTestModel.swift */; }; 681B35442A43962D0074F369 /* ScalarContainer.swift in Sources */ = {isa = PBXBuildFile; fileRef = 21262720289ABFE9003788E3 /* ScalarContainer.swift */; }; @@ -676,6 +677,7 @@ 39E0F2A128A43FB100939D9F /* AWSAPIPluginGraphQLUserPoolTests.xctest */ = {isa = PBXFileReference; explicitFileType = wrapper.cfbundle; includeInIndex = 0; path = AWSAPIPluginGraphQLUserPoolTests.xctest; sourceTree = BUILT_PRODUCTS_DIR; }; 39E0F2AC28A441B100939D9F /* TestConfigHelper.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = TestConfigHelper.swift; sourceTree = ""; }; 39E0F2AE28A4425C00939D9F /* Todo.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Todo.swift; sourceTree = ""; }; + 606C8B782B895E5A00716094 /* AppSyncRealTimeClientTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = AppSyncRealTimeClientTests.swift; sourceTree = ""; }; 681B35292A4395730074F369 /* APIWatchApp.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = APIWatchApp.app; sourceTree = BUILT_PRODUCTS_DIR; }; 681B35892A43962D0074F369 /* AWSAPIPluginFunctionalTestsWatch.xctest */ = {isa = PBXFileReference; explicitFileType = wrapper.cfbundle; includeInIndex = 0; path = AWSAPIPluginFunctionalTestsWatch.xctest; sourceTree = BUILT_PRODUCTS_DIR; }; 681B35A12A4396CF0074F369 /* AWSAPIPluginGraphQLLambdaAuthTestsWatch.xctest */ = {isa = PBXFileReference; explicitFileType = wrapper.cfbundle; includeInIndex = 0; path = AWSAPIPluginGraphQLLambdaAuthTestsWatch.xctest; sourceTree = BUILT_PRODUCTS_DIR; }; @@ -899,6 +901,7 @@ children = ( 21E581E32A6835910027D13A /* API.swift */, 212626CA289ABC79003788E3 /* Base */, + 606C8B782B895E5A00716094 /* AppSyncRealTimeClientTests.swift */, 21698AA82889996A004BD994 /* GraphQLConnectionScenario1Tests.swift */, 21E581E12A6707900027D13A /* GraphQLConnectionScenario1APISwiftTests.swift */, 21698AB62889996A004BD994 /* GraphQLConnectionScenario2Tests.swift */, @@ -2180,6 +2183,7 @@ 2126273D289ABFEB003788E3 /* Blog6.swift in Sources */, 2126274B289ABFEB003788E3 /* Comment.swift in Sources */, 2126273E289ABFEB003788E3 /* Post6.swift in Sources */, + 606C8B792B895E5A00716094 /* AppSyncRealTimeClientTests.swift in Sources */, 2126275F289ABFEB003788E3 /* Post3.swift in Sources */, 21262757289ABFEB003788E3 /* User5+Schema.swift in Sources */, 21262754289ABFEB003788E3 /* Blog6+Schema.swift in Sources */, diff --git a/AmplifyPlugins/API/Tests/APIHostApp/AWSAPIPluginFunctionalTests/AppSyncRealTimeClientTests.swift b/AmplifyPlugins/API/Tests/APIHostApp/AWSAPIPluginFunctionalTests/AppSyncRealTimeClientTests.swift new file mode 100644 index 0000000000..4084cd618b --- /dev/null +++ b/AmplifyPlugins/API/Tests/APIHostApp/AWSAPIPluginFunctionalTests/AppSyncRealTimeClientTests.swift @@ -0,0 +1,200 @@ +// +// Copyright Amazon.com Inc. or its affiliates. +// All Rights Reserved. +// +// SPDX-License-Identifier: Apache-2.0 +// + + +import XCTest +import Combine +@testable import Amplify +@testable import AWSAPIPlugin +@testable @_spi(WebSocket) import AWSPluginsCore + +class AppSyncRealTimeClientTests: XCTestCase { + let subscriptionRequest = """ + subscription MySubscription { + onCreatePost { + content + createdAt + draft + id + rating + status + title + updatedAt + } + } + """ + + var appSyncRealTimeClient: AppSyncRealTimeClient? + + override func setUp() async throws { + do { + Amplify.Logging.logLevel = .verbose + + let data = try TestConfigHelper.retrieve( + forResource: GraphQLModelBasedTests.amplifyConfiguration + ) + + let amplifyConfig = try JSONDecoder().decode(JSONValue.self, from: data) + let (endpoint, apiKey) = (amplifyConfig.api?.plugins?.awsAPIPlugin?.asObject?.values + .map { ($0.endpoint?.stringValue, $0.apiKey?.stringValue)} + .first { $0.0 != nil && $0.1 != nil } + .map { ($0.0!, $0.1!) })! + + + let webSocketClient = WebSocketClient( + url: AppSyncRealTimeClientFactory.appSyncRealTimeEndpoint(URL(string: endpoint)!), + protocols: ["graphql-ws"], + interceptor: APIKeyAuthInterceptor(apiKey: apiKey) + ) + appSyncRealTimeClient = AppSyncRealTimeClient( + endpoint: URL(string: endpoint)!, + requestInterceptor: APIKeyAuthInterceptor(apiKey: apiKey), + webSocketClient: webSocketClient + ) + + } catch { + XCTFail("Failed to setup appSyncRealTimeClient: \(error)") + } + } + + override func tearDown() async throws { + await appSyncRealTimeClient?.reset() + appSyncRealTimeClient = nil + } + + func testSubscribe_withSubscriptionConnection() async throws { + var cancellables = Set() + let subscribedExpectation = expectation(description: "Subscription established") + + try await appSyncRealTimeClient?.connect() + try await makeOneSubscription { event in + if case .subscribed = event { + subscribedExpectation.fulfill() + } + }?.store(in: &cancellables) + + await fulfillment(of: [subscribedExpectation], timeout: 5) + withExtendedLifetime(cancellables, { }) + } + + func testMultThreads_withConnectedClient_subscribeAndUnsubscribe() async throws { + var cancellables = [AnyCancellable?]() + let concurrentFactor = 90 + let expectedSubscription = expectation(description: "Multi threads subscription") + expectedSubscription.expectedFulfillmentCount = concurrentFactor + + let expectedUnsubscription = expectation(description: "Multi threads unsubscription") + expectedUnsubscription.expectedFulfillmentCount = concurrentFactor + cancellables = try await withThrowingTaskGroup( + of: AnyCancellable?.self, + returning: [AnyCancellable?].self + ) { taskGroup in + (0.. AnyCancellable? in + guard let self else { return nil } + let subscription = try await self.makeOneSubscription(id: id) { + if case .subscribed = $0 { + expectedSubscription.fulfill() + Task { + try await self.appSyncRealTimeClient?.unsubscribe(id: id) + } + } else if case .unsubscribed = $0 { + expectedUnsubscription.fulfill() + } + } + + return subscription + } + + } + + return try await taskGroup.reduce([AnyCancellable?]()) { $0 + [$1] } + } + + await fulfillment(of: [expectedSubscription, expectedUnsubscription], timeout: 3) + withExtendedLifetime(cancellables, { }) + } + + func testMaxSubscriptionReached_throwMaxSubscriptionsReachedError() async throws { + let numOfMaxSubscriptionCount = 100 + let maxSubsctiptionsSuccess = expectation(description: "Client can subscribe to max subscription count") + maxSubsctiptionsSuccess.expectedFulfillmentCount = numOfMaxSubscriptionCount + + var cancellables = try await withThrowingTaskGroup( + of: AnyCancellable?.self, + returning: [AnyCancellable?].self + ) { taskGroup in + (0.. AnyCancellable? in + guard let self else { return nil } + let subscription = try await self.makeOneSubscription(id: id) { + if case .subscribed = $0 { + maxSubsctiptionsSuccess.fulfill() + } + } + + return subscription + } + } + return try await taskGroup.reduce([AnyCancellable?]()) { $0 + [$1] } + } + + await fulfillment(of: [maxSubsctiptionsSuccess], timeout: 2) + + let maxSubscriptionReachedError = expectation(description: "Should return max subscription reached error") + maxSubscriptionReachedError.assertForOverFulfill = false + let retryTriggerredAndSucceed = expectation(description: "Retry on max subscription reached error and succeed") + cancellables.append(try await makeOneSubscription { event in + if case .error(let errors) = event { + XCTAssertTrue(errors.count == 1) + XCTAssertTrue(errors[0] is AppSyncRealTimeRequest.Error) + if case .maxSubscriptionsReached = errors[0] as! AppSyncRealTimeRequest.Error { + maxSubscriptionReachedError.fulfill() + cancellables.dropLast(10).forEach { $0?.cancel() } + } + } else if case .subscribed = event { + retryTriggerredAndSucceed.fulfill() + } + }) + await fulfillment(of: [maxSubscriptionReachedError, retryTriggerredAndSucceed], timeout: 5, enforceOrder: true) + withExtendedLifetime(cancellables, { }) + } + + private func makeOneSubscription( + id: String = UUID().uuidString, + onSubscriptionEvents: ((AppSyncSubscriptionEvent) -> Void)? + ) async throws -> AnyCancellable? { + let subscription = try await appSyncRealTimeClient?.subscribe( + id: id, + query: Self.appSyncQuery(with: self.subscriptionRequest) + ).sink(receiveValue: { + onSubscriptionEvents?($0) + }) + + return AnyCancellable { + subscription?.cancel() + Task { [weak self] in + try? await self?.appSyncRealTimeClient?.unsubscribe(id: id) + } + } + } + + private static func appSyncQuery( + with query: String, + variables: [String: JSONValue] = [:] + ) throws -> String { + let payload: JSONValue = .object([ + "query": .string(query), + "variables": (variables.isEmpty ? .null : .object(variables)) + ]) + let data = try JSONEncoder().encode(payload) + return String(data: data, encoding: .utf8)! + } + +} diff --git a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/AWSAPICategoryPlugin+ReachabilityTests.swift b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/AWSAPICategoryPlugin+ReachabilityTests.swift index 5f5f38ff2f..6123db2a2e 100644 --- a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/AWSAPICategoryPlugin+ReachabilityTests.swift +++ b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/AWSAPICategoryPlugin+ReachabilityTests.swift @@ -36,7 +36,7 @@ class AWSAPICategoryPluginReachabilityTests: XCTestCase { let dependencies = AWSAPIPlugin.ConfigurationDependencies( pluginConfig: pluginConfig, authService: MockAWSAuthService(), - subscriptionConnectionFactory: AWSSubscriptionConnectionFactory(), + appSyncRealTimeClientFactory: AppSyncRealTimeClientFactory(), logLevel: .error ) apiPlugin.configure(using: dependencies) @@ -64,7 +64,7 @@ class AWSAPICategoryPluginReachabilityTests: XCTestCase { let dependencies = AWSAPIPlugin.ConfigurationDependencies( pluginConfig: pluginConfig, authService: MockAWSAuthService(), - subscriptionConnectionFactory: AWSSubscriptionConnectionFactory(), + appSyncRealTimeClientFactory: AppSyncRealTimeClientFactory(), logLevel: .error ) apiPlugin.configure(using: dependencies) @@ -92,7 +92,7 @@ class AWSAPICategoryPluginReachabilityTests: XCTestCase { let dependencies = AWSAPIPlugin.ConfigurationDependencies( pluginConfig: pluginConfig, authService: MockAWSAuthService(), - subscriptionConnectionFactory: AWSSubscriptionConnectionFactory(), + appSyncRealTimeClientFactory: AppSyncRealTimeClientFactory(), logLevel: .error ) apiPlugin.configure(using: dependencies) diff --git a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/AWSAPICategoryPluginTestBase.swift b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/AWSAPICategoryPluginTestBase.swift index 5889fbf89e..9d6fa2b283 100644 --- a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/AWSAPICategoryPluginTestBase.swift +++ b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/AWSAPICategoryPluginTestBase.swift @@ -55,7 +55,7 @@ class AWSAPICategoryPluginTestBase: XCTestCase { let dependencies = AWSAPIPlugin.ConfigurationDependencies( pluginConfig: pluginConfig, authService: authService, - subscriptionConnectionFactory: AWSSubscriptionConnectionFactory(), + appSyncRealTimeClientFactory: AppSyncRealTimeClientFactory(), logLevel: .error ) apiPlugin.configure(using: dependencies) diff --git a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/AppSyncRealTimeClient/AppSyncRealTimeClientTests.swift b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/AppSyncRealTimeClient/AppSyncRealTimeClientTests.swift new file mode 100644 index 0000000000..279ca304d3 --- /dev/null +++ b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/AppSyncRealTimeClient/AppSyncRealTimeClientTests.swift @@ -0,0 +1,490 @@ +// +// Copyright Amazon.com Inc. or its affiliates. +// All Rights Reserved. +// +// SPDX-License-Identifier: Apache-2.0 +// + + +import XCTest +import Combine +import Amplify +@_spi(WebSocket) import AWSPluginsCore +@testable import AWSAPIPlugin + +class AppSyncRealTimeClientTests: XCTestCase { + + func testSendRequestWithTimeout_withNoResponse_failedWithTimeOutError() async { + let timeout = 1.0 + let mockWebSocketClient = MockWebSocketClient() + let mockAppSyncRequestInterceptor = MockAppSyncRequestInterceptor() + let appSyncClient = AppSyncRealTimeClient( + endpoint: URL(string: "https://example.com")!, + requestInterceptor: mockAppSyncRequestInterceptor, + webSocketClient: mockWebSocketClient + ) + + let requestFailedExpectation = expectation(description: "Request should be failed with error") + Task { + do { + try await appSyncClient.sendRequest(.connectionInit, timeout: timeout) + XCTFail("The operation should be failed with time out") + } catch { + let requestError = error as! AppSyncRealTimeRequest.Error + XCTAssert(requestError == .timeout) + requestFailedExpectation.fulfill() + } + } + await fulfillment(of: [requestFailedExpectation], timeout: timeout + 1) + } + + func testSendRequestWithTimeout_withCorrectResponse_succeed() async { + let timeout = 1.0 + let mockWebSocketClient = MockWebSocketClient() + let mockAppSyncRequestInterceptor = MockAppSyncRequestInterceptor() + let appSyncClient = AppSyncRealTimeClient( + endpoint: URL(string: "https://example.com")!, + requestInterceptor: mockAppSyncRequestInterceptor, + webSocketClient: mockWebSocketClient + ) + + let finishExpectation = expectation(description: "Request finished successfully") + Task { + do { + try await appSyncClient.sendRequest(.connectionInit, timeout: timeout) + finishExpectation.fulfill() + } catch { + XCTFail("Operation shouldn't fail with error \(error)") + } + } + Task { + try await Task.sleep(nanoseconds: 80 * 1000) + await appSyncClient.subject.send(.init(id: nil, payload: nil, type: .connectionAck)) + } + await fulfillment(of: [finishExpectation], timeout: timeout + 1) + } + + func testSendRequestWithTimeout_withErrorResponse_transformLimitExceededError() async { + let timeout = 1.0 + let mockWebSocketClient = MockWebSocketClient() + let mockAppSyncRequestInterceptor = MockAppSyncRequestInterceptor() + let appSyncClient = AppSyncRealTimeClient( + endpoint: URL(string: "https://example.com")!, + requestInterceptor: mockAppSyncRequestInterceptor, + webSocketClient: mockWebSocketClient + ) + + let limitExceededErrorExpectation = expectation(description: "Request should be failed with limitExceeded error") + let id = UUID().uuidString + Task { + do { + try await appSyncClient.sendRequest( + .start(.init(id: id, data: "", auth: nil)), + timeout: timeout + ) + XCTFail("Operation should be failed") + } catch { + let requestError = error as! AppSyncRealTimeRequest.Error + XCTAssertEqual(requestError, .limitExceeded) + limitExceededErrorExpectation.fulfill() + } + } + Task { + try await Task.sleep(nanoseconds: 80 * 1000) + await appSyncClient.subject.send(.init( + id: id, + payload: .object([ + "errors": .array([ + .object([ + "errorType": "LimitExceededError" + ]) + ]) + ]), + type: .error + )) + } + await fulfillment(of: [limitExceededErrorExpectation], timeout: timeout + 1) + } + + func testSendRequestWithTimeout_withErrorResponse_transformMaxSubscriptionsReachedError() async { + let timeout = 1.0 + let mockWebSocketClient = MockWebSocketClient() + let mockAppSyncRequestInterceptor = MockAppSyncRequestInterceptor() + let appSyncClient = AppSyncRealTimeClient( + endpoint: URL(string: "https://example.com")!, + requestInterceptor: mockAppSyncRequestInterceptor, + webSocketClient: mockWebSocketClient + ) + let maxSubscriptionsReachedExpectation = + expectation(description: "Request should be failed with maxSubscriptionsReached error") + let id = UUID().uuidString + Task { + do { + try await appSyncClient.sendRequest( + .start(.init(id: id, data: "", auth: nil)), + timeout: timeout + ) + XCTFail("Operation should be failed") + } catch { + let requestError = error as! AppSyncRealTimeRequest.Error + XCTAssertEqual(requestError, .maxSubscriptionsReached) + maxSubscriptionsReachedExpectation.fulfill() + } + } + + Task { + try await Task.sleep(nanoseconds: 80 * 1000) + await appSyncClient.subject.send(.init( + id: id, + payload: .object([ + "errors": .array([ + .object([ + "errorType": "MaxSubscriptionsReachedError" + ]) + ]) + ]), + type: .error + )) + } + await fulfillment(of: [ + maxSubscriptionsReachedExpectation + ], timeout: timeout + 1) + } + + func testSendRequestWithTimeout_withErrorResponse_triggerErrorForUnknow() async { + let timeout = 1.0 + let mockWebSocketClient = MockWebSocketClient() + let mockAppSyncRequestInterceptor = MockAppSyncRequestInterceptor() + let appSyncClient = AppSyncRealTimeClient( + endpoint: URL(string: "https://example.com")!, + requestInterceptor: mockAppSyncRequestInterceptor, + webSocketClient: mockWebSocketClient + ) + let triggerUnknownErrorExpectation = + expectation(description: "Request should trigger unknown errors") + let id = UUID().uuidString + Task { + do { + try await appSyncClient.sendRequest( + .start(.init(id: id, data: "", auth: nil)), + timeout: timeout + ) + } catch { + let requestError = error as! AppSyncRealTimeRequest.Error + guard case .unknown = requestError else { + XCTFail("The error should in unknown case") + return + } + triggerUnknownErrorExpectation.fulfill() + } + } + + Task { + try await Task.sleep(nanoseconds: 80 * 1000) + await appSyncClient.subject.send(.init( + id: id, + payload: .object([ + "errors": .array([ + .object([ + "errorType": "OtherError" + ]) + ]) + ]), + type: .error + )) + } + await fulfillment(of: [ + triggerUnknownErrorExpectation + ], timeout: timeout + 1) + } + + func testConnect_AppSyncRealTimeClient_triggersWebSocketConnection() async throws { + var cancellables = Set() + let mockWebSocketClient = MockWebSocketClient() + let mockAppSyncRequestInterceptor = MockAppSyncRequestInterceptor() + let appSyncClient = AppSyncRealTimeClient( + endpoint: URL(string: "https://example.com")!, + requestInterceptor: mockAppSyncRequestInterceptor, + webSocketClient: mockWebSocketClient + ) + + let connectTriggered = expectation(description: "webSocket connect API should be invoked") + await mockWebSocketClient.setStateToConnected() + await mockWebSocketClient.actionSubject + .sink { action in + if case let .connect(param1, param2) = action { + XCTAssertEqual(param1, true) + XCTAssertEqual(param2, true) + connectTriggered.fulfill() + } else if case let .write(message) = action { + XCTAssertEqual(message, """ + {"type":"connection_init"} + """) + } else { + XCTFail("No other actions should be invoked") + } + } + .store(in: &cancellables) + Task { try await appSyncClient.connect() } + Task { + try await Task.sleep(nanoseconds: 50 * 1_000_000) + await mockWebSocketClient.subject.send(.connected) + try await Task.sleep(nanoseconds: 50 * 1_000_000) + await mockWebSocketClient.subject.send(.string(""" + {"type": "connection_ack", "payload": { "connectionTimeoutMs": 300000 }} + """)) + } + + await fulfillment(of: [connectTriggered], timeout: 1) + } + + func testDisconnect_AppSyncRealTimeClient_triggersWebSocketDisconnect() async throws { + var cancellables = Set() + let mockWebSocketClient = MockWebSocketClient() + let mockAppSyncRequestInterceptor = MockAppSyncRequestInterceptor() + let appSyncClient = AppSyncRealTimeClient( + endpoint: URL(string: "https://example.com")!, + requestInterceptor: mockAppSyncRequestInterceptor, + webSocketClient: mockWebSocketClient + ) + + let disconnectTriggered = expectation(description: "webSocket disconnect API should be invoked") + await mockWebSocketClient.setStateToConnected() + await mockWebSocketClient.actionSubject + .sink { action in + if case .disconnect = action { + disconnectTriggered.fulfill() + } else { + XCTFail("No other actions should be invoked") + } + } + .store(in: &cancellables) + Task { await appSyncClient.disconnect() } + + await fulfillment(of: [disconnectTriggered], timeout: 1) + } + + func testUnsubscribe_withAppSyncRealTimeClientAlreadyConnected_triggersWebSocketStopEvent() async throws { + var cancellables = Set() + let mockWebSocketClient = MockWebSocketClient() + let mockAppSyncRequestInterceptor = MockAppSyncRequestInterceptor() + let appSyncClient = AppSyncRealTimeClient( + endpoint: URL(string: "https://example.com")!, + requestInterceptor: mockAppSyncRequestInterceptor, + webSocketClient: mockWebSocketClient + ) + let id = UUID().uuidString + + let connectTriggered = expectation(description: "connect websocket") + let startTriggered = expectation(description: "webSocket start subscription") + let stopTriggered = expectation(description: "webSocket writing stop event to connection") + + await mockWebSocketClient.setStateToConnected() + + await mockWebSocketClient.actionSubject + .sink { action in + switch action { + case .connect: + Task { + await mockWebSocketClient.subject.send(.connected) + } + + case .write(let message): + guard let response = try? JSONDecoder().decode( + JSONValue.self, + from: message.data(using: .utf8)! + ) else { + XCTFail("Response should be able to decode to AppSyncRealTimeResponse") + return + } + + switch response.type?.stringValue { + case .some("stop"): + XCTAssertEqual(response.id?.stringValue, id) + stopTriggered.fulfill() + + case .some("start"): + XCTAssertEqual(response.id?.stringValue, id) + startTriggered.fulfill() + Task { + try await Task.sleep(nanoseconds: 80 * 1_000_000) + await mockWebSocketClient.subject.send(.string(""" + {"type": "start_ack", "id": "\(id)"} + """)) + try await Task.sleep(nanoseconds: 80 * 1_000_000) + try await appSyncClient.unsubscribe(id: id) + } + + case .some("connection_init"): + connectTriggered.fulfill() + Task { + try await Task.sleep(nanoseconds: 80 * 1_000_000) + await mockWebSocketClient.subject.send(.string(""" + {"type": "connection_ack", "payload": { "connectionTimeoutMs": 300000 }} + """)) + } + default: + XCTFail("No other message should be written") + } + + default: + XCTFail("No other actions should be invoked") + } + } + .store(in: &cancellables) + + Task { + _ = try await appSyncClient.subscribe(id: id, query: "") + } + + await fulfillment( + of: [connectTriggered, startTriggered, stopTriggered], + timeout: 2, + enforceOrder: true + ) + } + + func testUnsubscribe_withAppSyncRealTimeClientNotConnected_doesNotTriggerWebSocketStopEvent() async throws { + var cancellables = Set() + let mockWebSocketClient = MockWebSocketClient() + let mockAppSyncRequestInterceptor = MockAppSyncRequestInterceptor() + let appSyncClient = AppSyncRealTimeClient( + endpoint: URL(string: "https://example.com")!, + requestInterceptor: mockAppSyncRequestInterceptor, + webSocketClient: mockWebSocketClient + ) + let id = UUID().uuidString + + let stopTriggered = expectation(description: "webSocket writing stop event to connection") + stopTriggered.isInverted = true + await mockWebSocketClient.actionSubject + .sink { action in + if case .write = action { + stopTriggered.fulfill() + } else { + XCTFail("No other actions should be invoked") + } + } + .store(in: &cancellables) + Task { try await appSyncClient.unsubscribe(id: id) } + + await fulfillment(of: [stopTriggered], timeout: 1) + } + + func testSubscribe_withAppSyncRealTimeClientAlreadyConnected_triggersWebSocketStartEvent() async throws { + var cancellables = Set() + let mockWebSocketClient = MockWebSocketClient() + let mockAppSyncRequestInterceptor = MockAppSyncRequestInterceptor() + let appSyncClient = AppSyncRealTimeClient( + endpoint: URL(string: "https://example.com")!, + requestInterceptor: mockAppSyncRequestInterceptor, + webSocketClient: mockWebSocketClient + ) + let id = UUID().uuidString + let query = UUID().uuidString + + let startTriggered = expectation(description: "webSocket writing start event to connection") + + await mockWebSocketClient.setStateToConnected() + Task { + try await Task.sleep(nanoseconds: 80 * 1_000_000) + await mockWebSocketClient.subject.send(.connected) + try await Task.sleep(nanoseconds: 80 * 1_000_000) + await mockWebSocketClient.subject.send(.string(""" + {"type": "connection_ack", "payload": { "connectionTimeoutMs": 300000 }} + """)) + } + try await appSyncClient.connect() + await mockWebSocketClient.actionSubject + .sink { action in + switch action { + case .write(let message): + guard let response = try? JSONDecoder().decode( + JSONValue.self, + from: message.data(using: .utf8)! + ) else { + XCTFail("Response should be able to decode to AppSyncRealTimeResponse") + return + } + + if response.type?.stringValue == "start" { + XCTAssertEqual(response.id?.stringValue, id) + XCTAssertEqual(response.payload?.asObject?["data"]?.stringValue, query) + startTriggered.fulfill() + } else { + XCTFail("No other message should be written") + } + + default: + XCTFail("No other actions should be invoked") + } + } + .store(in: &cancellables) + + + Task { try await appSyncClient.subscribe(id: id, query: query) } + + await fulfillment(of: [startTriggered], timeout: 2) + } + + func testSubscribe_withAppSyncRealTimeClientNotConnected_triggersWebSocketStartEvent() async throws { + var cancellables = Set() + let mockWebSocketClient = MockWebSocketClient() + let mockAppSyncRequestInterceptor = MockAppSyncRequestInterceptor() + let appSyncClient = AppSyncRealTimeClient( + endpoint: URL(string: "https://example.com")!, + requestInterceptor: mockAppSyncRequestInterceptor, + webSocketClient: mockWebSocketClient + ) + let id = UUID().uuidString + let query = UUID().uuidString + + let connectTriggered = expectation(description: "webSocket connection is invoked") + let sendingConnectInit = expectation(description: "Sending connection_init message") + let startTriggered = expectation(description: "webSocket writing start event to connection") + await mockWebSocketClient.actionSubject + .sink { action in + switch action { + case .connect: + connectTriggered.fulfill() + case .write(let message): + guard let response = try? JSONDecoder().decode( + JSONValue.self, + from: message.data(using: .utf8)! + ) else { + XCTFail("Response should be able to decode to AppSyncRealTimeResponse") + return + } + + if response.type?.stringValue == "connection_init" { + sendingConnectInit.fulfill() + } else if response.type?.stringValue == "start" { + XCTAssertEqual(response.id?.stringValue, id) + XCTAssertEqual(response.payload?.asObject?["data"]?.stringValue, query) + startTriggered.fulfill() + } else { + XCTFail("No other message should be written") + } + + default: + XCTFail("No other actions should be invoked") + } + } + .store(in: &cancellables) + Task { try await appSyncClient.subscribe(id: id, query: query) } + Task { + try await Task.sleep(nanoseconds: 50 * 1_000_000) + await mockWebSocketClient.setStateToConnected() + try await Task.sleep(nanoseconds: 50 * 1_000_000) + await mockWebSocketClient.subject.send(.string(""" + {"type": "connection_ack", "payload": { "connectionTimeoutMs": 300000 }} + """)) + } + + await fulfillment(of: [ + connectTriggered, + sendingConnectInit, + startTriggered + ], timeout: 3, enforceOrder: true) + } +} diff --git a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/AppSyncRealTimeClient/AppSyncRealTimeRequestAuthTests.swift b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/AppSyncRealTimeClient/AppSyncRealTimeRequestAuthTests.swift new file mode 100644 index 0000000000..6ab7af0692 --- /dev/null +++ b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/AppSyncRealTimeClient/AppSyncRealTimeRequestAuthTests.swift @@ -0,0 +1,215 @@ +// +// Copyright Amazon.com Inc. or its affiliates. +// All Rights Reserved. +// +// SPDX-License-Identifier: Apache-2.0 +// + + +import XCTest +@testable import AWSAPIPlugin + +class AppSyncRealTimeRequestAuthTests: XCTestCase { + let host = UUID().uuidString + let apiKey = UUID().uuidString + let date = UUID().uuidString + let id = UUID().uuidString + let data = UUID().uuidString + let token = UUID().uuidString + + var jsonEncoder = { + let encoder = JSONEncoder() + encoder.outputFormatting = [.sortedKeys] + return encoder + }() + + func testAppSyncRealTimeRequestAuth_encodeCognito() { + let cognitoAuth = AppSyncRealTimeRequestAuth.AuthToken(host: host, authToken: token) + XCTAssertEqual(toJson(cognitoAuth)?.shrink(), """ + { + "Authorization": "\(token)", + "host": "\(host)" + } + """.shrink()) + } + + func testAppSyncRealTimeRequestAuth_encodeApiKey() { + let apiKeyAuth = AppSyncRealTimeRequestAuth.ApiKey(host: host, apiKey: apiKey, amzDate: date) + XCTAssertEqual(toJson(apiKeyAuth)?.shrink(), """ + { + "host": "\(host)", + "x-amz-date": "\(date)", + "x-api-key": "\(apiKey)" + } + """.shrink()) + } + + func testAppSyncRealTimeRequestAuth_encodeIAM() { + let securityToken = UUID().uuidString + let iamAuth = AppSyncRealTimeRequestAuth.IAM( + host: host, + authToken: token, + securityToken: securityToken, + amzDate: date + ) + + XCTAssertEqual(toJson(iamAuth)?.shrink(), """ + { + "accept": "application\\/json, text\\/javascript", + "Authorization": "\(token)", + "content-encoding": "amz-1.0", + "content-type": "application\\/json; charset=UTF-8", + "host": "\(host)", + "x-amz-date": "\(date)", + "X-Amz-Security-Token": "\(securityToken)" + } + """.shrink()) + } + + func testAppSyncRealTimeRequestAuth_encodeStartRequestWithCognitoAuth() { + let auth: AppSyncRealTimeRequestAuth = .authToken(.init(host: host, authToken: token)) + let request = AppSyncRealTimeRequest.start( + .init(id: id, data: data, auth: auth) + ) + let requestJson = toJson(request) + XCTAssertEqual(requestJson?.shrink(), """ + { + "id": "\(id)", + "payload": { + "data": "\(data)", + "extensions": { + "authorization": { + "Authorization": "\(token)", + "host": "\(host)" + } + } + }, + "type": "start" + } + """.shrink()) + } + + func testAppSyncRealTimeRequestAuth_encodeStartRequestWithApiKeyAuth() { + let auth: AppSyncRealTimeRequestAuth = .apiKey(.init(host: host, apiKey: apiKey, amzDate: date)) + let request = AppSyncRealTimeRequest.start( + .init(id: id, data: data, auth: auth) + ) + let requestJson = toJson(request) + XCTAssertEqual(requestJson?.shrink(), """ + { + "id": "\(id)", + "payload": { + "data": "\(data)", + "extensions": { + "authorization": { + "host": "\(host)", + "x-amz-date": "\(date)", + "x-api-key": "\(apiKey)" + } + } + }, + "type": "start" + } + """.shrink()) + } + + func testAppSyncRealTimeRequestAuth_encodeStartRequestWithIAMAuth() { + let securityToken = UUID().uuidString + let iamAuth = AppSyncRealTimeRequestAuth.IAM( + host: host, + authToken: token, + securityToken: securityToken, + amzDate: date + ) + let request = AppSyncRealTimeRequest.start( + .init(id: id, data: data, auth: .iam(iamAuth)) + ) + let requestJson = toJson(request) + XCTAssertEqual(requestJson?.shrink(), """ + { + "id": "\(id)", + "payload": { + "data": "\(data)", + "extensions": { + "authorization": { + "accept": "application\\/json, text\\/javascript", + "Authorization": "\(token)", + "content-encoding": "amz-1.0", + "content-type": "application\\/json; charset=UTF-8", + "host": "\(host)", + "x-amz-date": "\(date)", + "X-Amz-Security-Token": "\(securityToken)" + } + } + }, + "type": "start" + } + """.shrink()) + } + + func testAppSyncRealTimeRequestAuth_URLQueryWithCognitoAuthHeader() { + let expectedURL = """ + https://example.com?\ + header=eyJBdXRob3JpemF0aW9uIjoiNDk4NTljN2MtNzQwNS00ZDU4LWFmZjctNTJiZ\ + TRiNDczNTU3IiwiaG9zdCI6ImV4YW1wbGUuY29tIn0%3D\ + &payload=e30%3D + """ + let encodedURL = AppSyncRealTimeRequestAuth.URLQuery( + header: .authToken(.init( + host: "example.com", + authToken: "49859c7c-7405-4d58-aff7-52be4b473557" + )) + ).withBaseURL(URL(string: "https://example.com")!, encoder: jsonEncoder) + XCTAssertEqual(encodedURL.absoluteString, expectedURL) + } + + func testAppSyncRealTimeRequestAuth_URLQueryWithApiKeyAuthHeader() { + let expectedURL = """ + https://example.com?\ + header=eyJob3N0IjoiZXhhbXBsZS5jb20iLCJ4LWFtei1kYXRlIjoiOWUwZTJkZjktMmVlNy00NjU5L\ + TgzNjItMWM4ODFlMTE4YzlmIiwieC1hcGkta2V5IjoiNjVlMmZhY2EtOGUxZS00ZDM3LThkYzctNjQ0N\ + 2Q5Njk4MjQ3In0%3D\ + &payload=e30%3D + """ + let encodedURL = AppSyncRealTimeRequestAuth.URLQuery( + header: .apiKey(.init( + host: "example.com", + apiKey: "65e2faca-8e1e-4d37-8dc7-6447d9698247", + amzDate: "9e0e2df9-2ee7-4659-8362-1c881e118c9f" + )) + ).withBaseURL(URL(string: "https://example.com")!, encoder: jsonEncoder) + XCTAssertEqual(encodedURL.absoluteString, expectedURL) + } + + func testAppSyncRealTimeRequestAuth_URLQueryWithIAMAuthHeader() { + + let expectedURL = """ + https://example.com?\ + header=eyJhY2NlcHQiOiJhcHBsaWNhdGlvblwvanNvbiwgdGV4dFwvamF2YXNjcmlwdCIsIkF1dGhvcml6YXR\ + pb24iOiJjOWRhZDg5Ny05MGQxLTRhNGMtYTVjOS0yYjM2YTI0NzczNWYiLCJjb250ZW50LWVuY29kaW5nIjoiY\ + W16LTEuMCIsImNvbnRlbnQtdHlwZSI6ImFwcGxpY2F0aW9uXC9qc29uOyBjaGFyc2V0PVVURi04IiwiaG9zdCI\ + 6ImV4YW1wbGUuY29tIiwieC1hbXotZGF0ZSI6IjllMGUyZGY5LTJlZTctNDY1OS04MzYyLTFjODgxZTExOGM5Z\ + iIsIlgtQW16LVNlY3VyaXR5LVRva2VuIjoiZTdlNjI2OWUtZmRhMS00ZGUwLThiZGItYmFhN2I2ZGQwYTBkIn0%3D\ + &payload=e30%3D + """ + let encodedURL = AppSyncRealTimeRequestAuth.URLQuery( + header: .iam(.init( + host: "example.com", + authToken: "c9dad897-90d1-4a4c-a5c9-2b36a247735f", + securityToken: "e7e6269e-fda1-4de0-8bdb-baa7b6dd0a0d", + amzDate: "9e0e2df9-2ee7-4659-8362-1c881e118c9f")) + ).withBaseURL(URL(string: "https://example.com")!, encoder: jsonEncoder) + XCTAssertEqual(encodedURL.absoluteString, expectedURL) + } + + private func toJson(_ value: Encodable) -> String? { + return try? String(data: jsonEncoder.encode(value), encoding: .utf8) + } +} + +fileprivate extension String { + func shrink() -> String { + return self.replacingOccurrences(of: "\n", with: "") + .replacingOccurrences(of: " ", with: "") + } +} diff --git a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Interceptor/SubscriptionInterceptor/APIKeyAuthInterceptorTests.swift b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Interceptor/SubscriptionInterceptor/APIKeyAuthInterceptorTests.swift new file mode 100644 index 0000000000..8c89c0a53a --- /dev/null +++ b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Interceptor/SubscriptionInterceptor/APIKeyAuthInterceptorTests.swift @@ -0,0 +1,56 @@ +// +// Copyright Amazon.com Inc. or its affiliates. +// All Rights Reserved. +// +// SPDX-License-Identifier: Apache-2.0 +// + + +import XCTest +import Amplify +@testable import AWSAPIPlugin + +class APIKeyAuthInterceptorTests: XCTestCase { + + func testInterceptConnection_addApiKeySignatureInURLQuery() async { + let apiKey = UUID().uuidString + let interceptor = APIKeyAuthInterceptor(apiKey: apiKey) + let resultUrl = await interceptor.interceptConnection(url: URL(string: "https://example.com")!) + guard let components = URLComponents(url: resultUrl, resolvingAgainstBaseURL: false) else { + XCTFail("Failed to decode decorated URL") + return + } + + let header = components.queryItems?.first { $0.name == "header" } + XCTAssertNotNil(header?.value) + let headerData = try! header?.value!.base64DecodedString().data(using: .utf8) + let decodedHeader = try! JSONDecoder().decode(JSONValue.self, from: headerData!) + XCTAssertEqual(decodedHeader["x-api-key"]?.stringValue, apiKey) + } + + func testInterceptRequest_appendAuthInfoInPayload() async { + let apiKey = UUID().uuidString + let interceptor = APIKeyAuthInterceptor(apiKey: apiKey) + let decoratedRequest = await interceptor.interceptRequest( + event: AppSyncRealTimeRequest.start(.init( + id: UUID().uuidString, + data: "", + auth: nil + )), + url: URL(string: "https://example.appsync-realtime-api.amazonaws.com")! + ) + guard case let .start(request) = decoratedRequest else { + XCTFail("Request should be a start request") + return + } + + XCTAssertNotNil(request.auth) + guard case let .apiKey(apiKeyInfo) = request.auth! else { + XCTFail("Auth should be api key") + return + } + + XCTAssertEqual(apiKeyInfo.apiKey, apiKey) + XCTAssertEqual(apiKeyInfo.host, "example.appsync-api.amazonaws.com") + } +} diff --git a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Interceptor/SubscriptionInterceptor/AuthenticationTokenAuthInterceptorTests.swift b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Interceptor/SubscriptionInterceptor/AuthenticationTokenAuthInterceptorTests.swift deleted file mode 100644 index d9af096d54..0000000000 --- a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Interceptor/SubscriptionInterceptor/AuthenticationTokenAuthInterceptorTests.swift +++ /dev/null @@ -1,52 +0,0 @@ -// -// Copyright Amazon.com Inc. or its affiliates. -// All Rights Reserved. -// -// SPDX-License-Identifier: Apache-2.0 -// - -import XCTest -import Amplify -import AppSyncRealTimeClient -@testable import AWSAPIPlugin -@testable import AmplifyTestCommon - -class AuthenticationTokenAuthInterceptorTests: XCTestCase { - - func testAuthenticationTokenInterceptor() async throws { - let url = URL(string: "http://awssubscriptionurl.ca")! - let request = AppSyncConnectionRequest(url: url) - let interceptor = AuthenticationTokenAuthInterceptor(authTokenProvider: TestAuthTokenProvider()) - let interceptedRequest = await interceptor.interceptConnection(request, for: url) - - XCTAssertNotNil(interceptedRequest.url.query) - } - - func testDoesNotAddAuthHeaderIfTokenProviderReturnsError() async throws { - let url = URL(string: "http://awssubscriptionurl.ca")! - let request = AppSyncConnectionRequest(url: url) - let interceptor = AuthenticationTokenAuthInterceptor(authTokenProvider: TestFailingAuthTokenProvider()) - let interceptedRequest = await interceptor.interceptConnection(request, for: url) - - XCTAssertNil(interceptedRequest.url.query) - } -} - -// MARK: - Test token providers -private class TestAuthTokenProvider: AmplifyAuthTokenProvider { - - let authToken = "token" - - func getLatestAuthToken() async throws -> String { - authToken - } -} - -private class TestFailingAuthTokenProvider: AmplifyAuthTokenProvider { - - let authToken = "token" - - func getLatestAuthToken() async throws -> String { - throw "Token error" - } -} diff --git a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Interceptor/SubscriptionInterceptor/CognitoAuthInterceptorTests.swift b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Interceptor/SubscriptionInterceptor/CognitoAuthInterceptorTests.swift new file mode 100644 index 0000000000..4127f018fd --- /dev/null +++ b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Interceptor/SubscriptionInterceptor/CognitoAuthInterceptorTests.swift @@ -0,0 +1,125 @@ +// +// Copyright Amazon.com Inc. or its affiliates. +// All Rights Reserved. +// +// SPDX-License-Identifier: Apache-2.0 +// + + +import XCTest +import Amplify +@testable import AWSAPIPlugin +@testable @_spi(WebSocket) import AWSPluginsCore + +class CognitoAuthInterceptorTests: XCTestCase { + + func testInterceptConnection_withAuthTokenProvider_appendCorrectAuthHeaderToQuery() async { + let authTokenProvider = MockAuthTokenProvider() + let interceptor = AuthTokenInterceptor(authTokenProvider: authTokenProvider) + + let decoratedURL = await interceptor.interceptConnection(url: URL(string: "https://example.com")!) + guard let components = URLComponents(url: decoratedURL, resolvingAgainstBaseURL: false) else { + XCTFail("Failed to get url components from decorated URL") + return + } + + guard let queryHeaderString = + try? components.queryItems?.first(where: { $0.name == "header" })?.value?.base64DecodedString() + else { + XCTFail("Failed to extract header field from query string") + return + } + + guard let queryHeader = try? JSONDecoder().decode(JSONValue.self, from: queryHeaderString.data(using: .utf8)!) + else { + XCTFail("Failed to decode query header to json object") + return + } + XCTAssertEqual(authTokenProvider.authToken, queryHeader.Authorization?.stringValue) + XCTAssertEqual("example.com", queryHeader.host?.stringValue) + } + + func testInterceptConnection_withAuthTokenProviderFailed_appendEmptyAuthHeaderToQuery() async { + let authTokenProvider = MockAuthTokenProviderFailed() + let interceptor = AuthTokenInterceptor(authTokenProvider: authTokenProvider) + + let decoratedURL = await interceptor.interceptConnection(url: URL(string: "https://example.com")!) + guard let components = URLComponents(url: decoratedURL, resolvingAgainstBaseURL: false) else { + XCTFail("Failed to get url components from decorated URL") + return + } + + guard let queryHeaderString = + try? components.queryItems?.first(where: { $0.name == "header" })?.value?.base64DecodedString() + else { + XCTFail("Failed to extract header field from query string") + return + } + + guard let queryHeader = try? JSONDecoder().decode(JSONValue.self, from: queryHeaderString.data(using: .utf8)!) + else { + XCTFail("Failed to decode query header to json object") + return + } + XCTAssertEqual("", queryHeader.Authorization?.stringValue) + XCTAssertEqual("example.com", queryHeader.host?.stringValue) + } + + func testInterceptRequest_withAuthTokenProvider_appendCorrectAuthInfoToPayload() async { + let authTokenProvider = MockAuthTokenProvider() + let interceptor = AuthTokenInterceptor(authTokenProvider: authTokenProvider) + let decoratedRequest = await interceptor.interceptRequest( + event: .start(.init(id: UUID().uuidString, data: UUID().uuidString, auth: nil)), + url: URL(string: "https://example.com")! + ) + + guard case let .start(decoratedAuth) = decoratedRequest else { + XCTFail("Failed to extract decoratedAuth info") + return + } + + guard case let .some(.authToken(authInfo)) = decoratedAuth.auth else { + XCTFail("Failed to extract authInfo from decoratedAuth") + return + } + + XCTAssertEqual(authTokenProvider.authToken, authInfo.authToken) + XCTAssertEqual("example.com", authInfo.host) + } + + func testInterceptRequest_withAuthTokenProviderFailed_appendEmptyAuthInfoToPayload() async { + let authTokenProvider = MockAuthTokenProviderFailed() + let interceptor = AuthTokenInterceptor(authTokenProvider: authTokenProvider) + let decoratedRequest = await interceptor.interceptRequest( + event: .start(.init(id: UUID().uuidString, data: UUID().uuidString, auth: nil)), + url: URL(string: "https://example.com")! + ) + + guard case let .start(decoratedAuth) = decoratedRequest else { + XCTFail("Failed to extract decoratedAuth info") + return + } + + guard case let .some(.authToken(authInfo)) = decoratedAuth.auth else { + XCTFail("Failed to extract authInfo from decoratedAuth") + return + } + + XCTAssertEqual("", authInfo.authToken) + XCTAssertEqual("example.com", authInfo.host) + } +} + +fileprivate class MockAuthTokenProvider: AmplifyAuthTokenProvider { + let authToken = UUID().uuidString + func getLatestAuthToken() async throws -> String { + return authToken + } +} + +fileprivate class MockAuthTokenProviderFailed: AmplifyAuthTokenProvider { + let authToken = UUID().uuidString + func getLatestAuthToken() async throws -> String { + throw "Intended" + } +} diff --git a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Interceptor/SubscriptionInterceptor/IAMAuthInterceptorTests.swift b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Interceptor/SubscriptionInterceptor/IAMAuthInterceptorTests.swift deleted file mode 100644 index c4ec4e0405..0000000000 --- a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Interceptor/SubscriptionInterceptor/IAMAuthInterceptorTests.swift +++ /dev/null @@ -1,114 +0,0 @@ -// -// Copyright Amazon.com Inc. or its affiliates. -// All Rights Reserved. -// -// SPDX-License-Identifier: Apache-2.0 -// - -import XCTest -import Amplify -import AppSyncRealTimeClient -@testable import AWSPluginsTestCommon -@testable import AWSAPIPlugin -@testable import AmplifyTestCommon - -class IAMAuthInterceptorTests: XCTestCase { - - func testIAMAuthenticationHeader() throws { - let expectedAdditionalHeaders = ["extra-header": "headerValue"] - let authHeader = IAMAuthenticationHeader(host: "host", - authorization: "auth", - securityToken: "token", - amzDate: "date", - accept: "accept", - contentEncoding: "encoding", - contentType: "type", - additionalHeaders: expectedAdditionalHeaders) - XCTAssertEqual(authHeader.authorization, "auth") - XCTAssertEqual(authHeader.securityToken, "token") - XCTAssertEqual(authHeader.amzDate, "date") - XCTAssertEqual(authHeader.accept, "accept") - XCTAssertEqual(authHeader.contentEncoding, "encoding") - XCTAssertEqual(authHeader.contentType, "type") - XCTAssertEqual(authHeader.additionalHeaders, expectedAdditionalHeaders) - } - - func testIAMAuthenticationHeaderEncodable() throws { - let authHeader = IAMAuthenticationHeader(host: "host", - authorization: "auth", - securityToken: "token", - amzDate: "date", - accept: "accept", - contentEncoding: "encoding", - contentType: "type", - additionalHeaders: nil) - - let encoder = JSONEncoder() - let serializedJSON = try encoder.encode(authHeader) - let decoder = JSONDecoder() - let json = try decoder.decode(JSONValue.self, from: serializedJSON) - - guard case let .object(jsonObject) = json else { - XCTFail("Failed to get JSON object") - return - } - XCTAssertEqual(jsonObject["host"], "host") - XCTAssertEqual(jsonObject[SubscriptionConstants.authorizationkey], "auth") - XCTAssertEqual(jsonObject[RealtimeProviderConstants.iamSecurityTokenKey], "token") - XCTAssertEqual(jsonObject[RealtimeProviderConstants.amzDate], "date") - XCTAssertEqual(jsonObject[RealtimeProviderConstants.acceptKey], "accept") - XCTAssertEqual(jsonObject[RealtimeProviderConstants.contentEncodingKey], "encoding") - XCTAssertEqual(jsonObject[RealtimeProviderConstants.contentTypeKey], "type") - XCTAssertEqual(jsonObject.count, 7) - } - - func testIAMAuthenticationHeaderEncodableWithAdditionalHeaders() throws { - let expectedAdditionalHeaders = ["extra-header": "headerValue"] - let authHeader = IAMAuthenticationHeader(host: "host", - authorization: "auth", - securityToken: "token", - amzDate: "date", - accept: "accept", - contentEncoding: "encoding", - contentType: "type", - additionalHeaders: expectedAdditionalHeaders) - - let encoder = JSONEncoder() - let serializedJSON = try encoder.encode(authHeader) - let decoder = JSONDecoder() - let json = try decoder.decode(JSONValue.self, from: serializedJSON) - - guard case let .object(jsonObject) = json else { - XCTFail("Failed to get JSON object") - return - } - XCTAssertEqual(jsonObject["host"], "host") - XCTAssertEqual(jsonObject[SubscriptionConstants.authorizationkey], "auth") - XCTAssertEqual(jsonObject[RealtimeProviderConstants.iamSecurityTokenKey], "token") - XCTAssertEqual(jsonObject[RealtimeProviderConstants.amzDate], "date") - XCTAssertEqual(jsonObject[RealtimeProviderConstants.acceptKey], "accept") - XCTAssertEqual(jsonObject[RealtimeProviderConstants.contentEncodingKey], "encoding") - XCTAssertEqual(jsonObject[RealtimeProviderConstants.contentTypeKey], "type") - XCTAssertEqual(jsonObject["extra-header"], "headerValue") - XCTAssertEqual(jsonObject.count, 8) - } - - func testInterceptConnection() async { - let mockAuthService = MockAWSAuthService() - let interceptor = IAMAuthInterceptor(mockAuthService.getCredentialsProvider(), region: "us-west-2") - let url = URL(string: "https://abc.appsync-api.us-west-2.amazonaws.com/graphql")! - let signer = MockAWSSignatureV4Signer() - guard let authHeader = await interceptor.getAuthHeader(url, with: "payload", signer: signer) else { - XCTFail("Could not get authHeader") - return - } - - XCTAssertNotNil(authHeader.authorization) - XCTAssertNotNil(authHeader.securityToken) - XCTAssertNotNil(authHeader.amzDate) - XCTAssertEqual(authHeader.accept, "application/json, text/javascript") - XCTAssertEqual(authHeader.contentEncoding, "amz-1.0") - XCTAssertEqual(authHeader.contentType, "application/json; charset=UTF-8") - XCTAssertNil(authHeader.additionalHeaders) - } -} diff --git a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Mocks/MockSubscription.swift b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Mocks/MockSubscription.swift index b9b2116ab7..2ba9f97779 100644 --- a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Mocks/MockSubscription.swift +++ b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Mocks/MockSubscription.swift @@ -7,20 +7,21 @@ import Foundation -@testable import AWSAPIPlugin import Amplify +import Combine +@testable import AWSAPIPlugin +@_spi(WebSocket) import AWSPluginsCore -import AWSPluginsCore -import AppSyncRealTimeClient +struct MockSubscriptionConnectionFactory: AppSyncRealTimeClientFactoryProtocol { + -struct MockSubscriptionConnectionFactory: SubscriptionConnectionFactory { typealias OnGetOrCreateConnection = ( AWSAPICategoryPluginConfiguration.EndpointConfig, - URLRequest, + URL, AWSAuthServiceBehavior, AWSAuthorizationType?, APIAuthProviderFactory - ) throws -> SubscriptionConnection + ) async throws -> AppSyncRealTimeClientProtocol let onGetOrCreateConnection: OnGetOrCreateConnection @@ -28,45 +29,115 @@ struct MockSubscriptionConnectionFactory: SubscriptionConnectionFactory { self.onGetOrCreateConnection = onGetOrCreateConnection } - func getOrCreateConnection( + func getAppSyncRealTimeClient( for endpointConfig: AWSAPICategoryPluginConfiguration.EndpointConfig, - urlRequest: URLRequest, + endpoint: URL, authService: AWSAuthServiceBehavior, authType: AWSAuthorizationType?, apiAuthProviderFactory: APIAuthProviderFactory - ) throws -> SubscriptionConnection { - try onGetOrCreateConnection(endpointConfig, urlRequest, authService, authType, apiAuthProviderFactory) + ) async throws -> AppSyncRealTimeClientProtocol { + try await onGetOrCreateConnection(endpointConfig, endpoint, authService, authType, apiAuthProviderFactory) } +} + +class MockAppSyncRealTimeClient: AppSyncRealTimeClientProtocol { + + private let subject = PassthroughSubject() + + func subscribe(id: String, query: String) async throws -> AnyPublisher { + defer { + + Task { + try await Task.sleep(seconds: 0.25) + subject.send(.subscribing) + try await Task.sleep(seconds: 0.45) + subject.send(.subscribed) + } + } + return subject.eraseToAnyPublisher() + } + + func unsubscribe(id: String) async throws { + try await Task.sleep(seconds: 0.45) + subject.send(.unsubscribed) + } + + func connect() async throws { } + + func disconnectWhenIdel() async { } + + func disconnect() async { } + + func triggerEvent(_ event: AppSyncSubscriptionEvent) { + subject.send(event) + } + + static func waitForSubscirbing() async throws { + try await Task.sleep(seconds: 0.3) + } + + static func waitForSubscirbed() async throws { + try await Task.sleep(seconds: 0.5) + } + + static func waitForUnsubscirbed() async throws { + try await Task.sleep(seconds: 0.5) + } +} + +class MockAppSyncRequestInterceptor: AppSyncRequestInterceptor { + func interceptRequest(event: AppSyncRealTimeRequest, url: URL) async -> AppSyncRealTimeRequest { + return event + } } -struct MockSubscriptionConnection: SubscriptionConnection { - typealias OnSubscribe = ( - String, - [String: Any?]?, - @escaping SubscriptionEventHandler - ) -> SubscriptionItem +actor MockWebSocketClient: AppSyncWebSocketClientProtocol { + enum State { + case none + case connected + } + + enum Action { + case connect(Bool, Bool) + case disconnect + case write(String) + } + + var actionSubject = PassthroughSubject() + var subject = PassthroughSubject() + var state: State + + var isConnected: Bool { + state == .connected + } + + var publisher: AnyPublisher { + subject.eraseToAnyPublisher() + } - typealias OnUnsubscribe = (SubscriptionItem) -> Void + init() { + self.state = .none + } - let onSubscribe: OnSubscribe - let onUnsubscribe: OnUnsubscribe + deinit { + subject.send(completion: .finished) + actionSubject.send(completion: .finished) + } - init(onSubscribe: @escaping OnSubscribe, onUnsubscribe: @escaping OnUnsubscribe) { - self.onSubscribe = onSubscribe - self.onUnsubscribe = onUnsubscribe + func connect(autoConnectOnNetworkStatusChange: Bool, autoRetryOnConnectionFailure: Bool) { + actionSubject.send(.connect(autoConnectOnNetworkStatusChange, autoRetryOnConnectionFailure)) } - func subscribe( - requestString: String, - variables: [String: Any?]?, - eventHandler: @escaping SubscriptionEventHandler - ) -> SubscriptionItem { - onSubscribe(requestString, variables, eventHandler) + func disconnect() { + actionSubject.send(.disconnect) } - func unsubscribe(item: SubscriptionItem) { - onUnsubscribe(item) + func write(message: String) throws { + actionSubject.send(.write(message)) } + func setStateToConnected() { + self.state = .connected + } } diff --git a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Operation/AWSGraphQLSubscriptionOperationCancelTests.swift b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Operation/AWSGraphQLSubscriptionOperationCancelTests.swift index b54fe33c70..95fc5b8e63 100644 --- a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Operation/AWSGraphQLSubscriptionOperationCancelTests.swift +++ b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Operation/AWSGraphQLSubscriptionOperationCancelTests.swift @@ -10,8 +10,7 @@ import XCTest @testable import Amplify @testable import AWSAPIPlugin @testable import AmplifyTestCommon -@testable import AppSyncRealTimeClient -@testable import AWSPluginsCore +@testable @_spi(WebSocket) import AWSPluginsCore @testable import AWSPluginsTestCommon // swiftlint:disable:next type_name @@ -30,7 +29,7 @@ class AWSGraphQLSubscriptionOperationCancelTests: XCTestCase { let testBody = Data() let testPath = "testPath" - func setUp(subscriptionConnectionFactory: SubscriptionConnectionFactory) async { + func setUp(mockAppSyncRealTimeClientFactory: MockSubscriptionConnectionFactory) async { apiPlugin = AWSAPIPlugin() let authService = MockAWSAuthService() @@ -50,7 +49,7 @@ class AWSGraphQLSubscriptionOperationCancelTests: XCTestCase { let dependencies = AWSAPIPlugin.ConfigurationDependencies( pluginConfig: pluginConfig, authService: authService, - subscriptionConnectionFactory: subscriptionConnectionFactory, + appSyncRealTimeClientFactory: mockAppSyncRealTimeClientFactory, logLevel: .error ) apiPlugin.configure(using: dependencies) @@ -69,15 +68,9 @@ class AWSGraphQLSubscriptionOperationCancelTests: XCTestCase { func testCancelSendsCompletion() async { let mockSubscriptionConnectionFactory = MockSubscriptionConnectionFactory(onGetOrCreateConnection: { _, _, _, _, _ in - return MockSubscriptionConnection(onSubscribe: { (_, _, eventHandler) -> SubscriptionItem in - let item = SubscriptionItem(requestString: "", variables: nil, eventHandler: { _, _ in - }) - eventHandler(.connection(.connecting), item) - return item - }, onUnsubscribe: {_ in - }) + MockAppSyncRealTimeClient() }) - await setUp(subscriptionConnectionFactory: mockSubscriptionConnectionFactory) + await setUp(mockAppSyncRealTimeClientFactory: mockSubscriptionConnectionFactory) let request = GraphQLRequest(apiName: apiName, document: testDocument, @@ -93,10 +86,10 @@ class AWSGraphQLSubscriptionOperationCancelTests: XCTestCase { case .connecting: print("1/3 Subscription is connecting") receivedValueConnecting.fulfill() + case .connected: + break case .disconnected: break - default: - XCTFail("Unexpected value on value listener: \(state)") } default: XCTFail("Unexpected value on on value listener: \(value)") @@ -116,7 +109,7 @@ class AWSGraphQLSubscriptionOperationCancelTests: XCTestCase { let receivedFailure = expectation(description: "Received failure") receivedFailure.isInverted = true let receivedValueDisconnected = expectation(description: "Received value for disconnected") - + _ = operation.subscribe(inProcessListener: { value in switch value { case .connection(let state): @@ -126,8 +119,8 @@ class AWSGraphQLSubscriptionOperationCancelTests: XCTestCase { case .disconnected: print("2/3 Subscription is disconnected") receivedValueDisconnected.fulfill() - default: - XCTFail("Unexpected value on value listener: \(state)") + case .connected: + break } default: XCTFail("Unexpected value on on value listener: \(value)") @@ -148,7 +141,7 @@ class AWSGraphQLSubscriptionOperationCancelTests: XCTestCase { await fulfillment( of: [receivedCompletion, receivedFailure, receivedValueDisconnected], - timeout: 5 + timeout: 1 ) } @@ -157,7 +150,7 @@ class AWSGraphQLSubscriptionOperationCancelTests: XCTestCase { throw APIError.invalidConfiguration("something went wrong", "", nil) }) - await setUp(subscriptionConnectionFactory: mockSubscriptionConnectionFactory) + await setUp(mockAppSyncRealTimeClientFactory: mockSubscriptionConnectionFactory) let request = GraphQLRequest(apiName: apiName, document: testDocument, @@ -201,16 +194,10 @@ class AWSGraphQLSubscriptionOperationCancelTests: XCTestCase { let connectionCreation = expectation(description: "connection factory called") let mockSubscriptionConnectionFactory = MockSubscriptionConnectionFactory(onGetOrCreateConnection: { _, _, _, _, _ in connectionCreation.fulfill() - return MockSubscriptionConnection(onSubscribe: { (_, _, eventHandler) -> SubscriptionItem in - let item = SubscriptionItem(requestString: "", variables: nil, eventHandler: { _, _ in - }) - eventHandler(.connection(.connecting), item) - return item - }, onUnsubscribe: {_ in - }) + return MockAppSyncRealTimeClient() }) - await setUp(subscriptionConnectionFactory: mockSubscriptionConnectionFactory) + await setUp(mockAppSyncRealTimeClientFactory: mockSubscriptionConnectionFactory) let request = GraphQLRequest(apiName: apiName, document: testDocument, @@ -252,7 +239,7 @@ class AWSGraphQLSubscriptionOperationCancelTests: XCTestCase { XCTAssert(operation.isCancelled) await fulfillment( of: [receivedCompletion, receivedFailure], - timeout: 5 + timeout: 1 ) } } diff --git a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Operation/AWSGraphQLSubscriptionTaskRunnerCancelTests.swift b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Operation/AWSGraphQLSubscriptionTaskRunnerCancelTests.swift index 583886dff3..a06001515f 100644 --- a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Operation/AWSGraphQLSubscriptionTaskRunnerCancelTests.swift +++ b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Operation/AWSGraphQLSubscriptionTaskRunnerCancelTests.swift @@ -10,7 +10,6 @@ import XCTest @testable import Amplify @testable import AWSAPIPlugin @testable import AmplifyTestCommon -@testable import AppSyncRealTimeClient @testable import AWSPluginsCore @testable import AWSPluginsTestCommon @@ -30,7 +29,7 @@ class AWSGraphQLSubscriptionTaskRunnerCancelTests: XCTestCase { let testBody = Data() let testPath = "testPath" - func setUp(subscriptionConnectionFactory: SubscriptionConnectionFactory) async { + func setUp(appSyncRealTimeClientFactory: AppSyncRealTimeClientFactoryProtocol) async { apiPlugin = AWSAPIPlugin() let authService = MockAWSAuthService() @@ -50,7 +49,7 @@ class AWSGraphQLSubscriptionTaskRunnerCancelTests: XCTestCase { let dependencies = AWSAPIPlugin.ConfigurationDependencies( pluginConfig: pluginConfig, authService: authService, - subscriptionConnectionFactory: subscriptionConnectionFactory, + appSyncRealTimeClientFactory: appSyncRealTimeClientFactory, logLevel: .error ) apiPlugin.configure(using: dependencies) @@ -67,17 +66,12 @@ class AWSGraphQLSubscriptionTaskRunnerCancelTests: XCTestCase { } } - func testCancelSendsCompletion() async { + func testCancelSendsCompletion() async throws { let mockSubscriptionConnectionFactory = MockSubscriptionConnectionFactory(onGetOrCreateConnection: { _, _, _, _, _ in - return MockSubscriptionConnection(onSubscribe: { (_, _, eventHandler) -> SubscriptionItem in - let item = SubscriptionItem(requestString: "", variables: nil, eventHandler: { _, _ in - }) - eventHandler(.connection(.connecting), item) - return item - }, onUnsubscribe: {_ in - }) + return MockAppSyncRealTimeClient() }) - await setUp(subscriptionConnectionFactory: mockSubscriptionConnectionFactory) + + await setUp(appSyncRealTimeClientFactory: mockSubscriptionConnectionFactory) let request = GraphQLRequest(apiName: apiName, document: testDocument, @@ -114,15 +108,16 @@ class AWSGraphQLSubscriptionTaskRunnerCancelTests: XCTestCase { } await fulfillment(of: [receivedValueConnecting], timeout: 1) subscriptionEvents.cancel() + try await MockAppSyncRealTimeClient.waitForUnsubscirbed() await fulfillment(of: [receivedValueDisconnected, receivedCompletion, receivedFailure], timeout: 1) } func testFailureOnConnection() async { - let mockSubscriptionConnectionFactory = MockSubscriptionConnectionFactory(onGetOrCreateConnection: { _, _, _, _, _ in + let mockAppSyncRealTimeClientFactory = MockSubscriptionConnectionFactory(onGetOrCreateConnection: { _, _, _, _, _ in throw APIError.invalidConfiguration("something went wrong", "", nil) }) - await setUp(subscriptionConnectionFactory: mockSubscriptionConnectionFactory) + await setUp(appSyncRealTimeClientFactory: mockAppSyncRealTimeClientFactory) let request = GraphQLRequest(apiName: apiName, document: testDocument, @@ -154,16 +149,10 @@ class AWSGraphQLSubscriptionTaskRunnerCancelTests: XCTestCase { let connectionCreation = expectation(description: "connection factory called") let mockSubscriptionConnectionFactory = MockSubscriptionConnectionFactory(onGetOrCreateConnection: { _, _, _, _, _ in connectionCreation.fulfill() - return MockSubscriptionConnection(onSubscribe: { (_, _, eventHandler) -> SubscriptionItem in - let item = SubscriptionItem(requestString: "", variables: nil, eventHandler: { _, _ in - }) - eventHandler(.connection(.connecting), item) - return item - }, onUnsubscribe: {_ in - }) + return MockAppSyncRealTimeClient() }) - await setUp(subscriptionConnectionFactory: mockSubscriptionConnectionFactory) + await setUp(appSyncRealTimeClientFactory: mockSubscriptionConnectionFactory) let request = GraphQLRequest(apiName: apiName, document: testDocument, diff --git a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Operation/GraphQLSubscribeCombineTests.swift b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Operation/GraphQLSubscribeCombineTests.swift index b21b7f6528..cbe8ad220c 100644 --- a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Operation/GraphQLSubscribeCombineTests.swift +++ b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Operation/GraphQLSubscribeCombineTests.swift @@ -12,7 +12,6 @@ import Amplify @testable import AmplifyTestCommon @testable import AWSAPIPlugin @_implementationOnly import AmplifyAsyncTesting -import AppSyncRealTimeClient class GraphQLSubscribeCombineTests: OperationTestBase { @@ -32,10 +31,7 @@ class GraphQLSubscribeCombineTests: OperationTestBase { var receivedDataValueSuccess: XCTestExpectation! var receivedDataValueError: XCTestExpectation! - // Handles to the subscription item and event handler used to make mock calls into the - // subscription system - var subscriptionItem: SubscriptionItem! - var subscriptionEventHandler: SubscriptionEventHandler! + var mockAppSyncRealTimeClient: MockAppSyncRealTimeClient? var connectionStateSink: AnyCancellable? var subscriptionDataSink: AnyCancellable? @@ -57,6 +53,21 @@ class GraphQLSubscribeCombineTests: OperationTestBase { try setUpMocksAndSubscriptionItems() } + override func tearDown() async throws { + self.sink?.cancel() + self.connectionStateSink?.cancel() + self.subscriptionDataSink?.cancel() + self.onSubscribeInvoked = nil + self.receivedCompletionFailure = nil + self.receivedCompletionSuccess = nil + self.receivedDataValueError = nil + self.receivedDataValueSuccess = nil + self.receivedStateValueConnected = nil + self.receivedStateValueConnecting = nil + self.receivedStateValueDisconnected = nil + try await super.tearDown() + } + func waitForSubscriptionExpectations() async { await fulfillment(of: [receivedCompletionSuccess, receivedCompletionFailure, @@ -72,15 +83,18 @@ class GraphQLSubscribeCombineTests: OperationTestBase { receivedDataValueError.isInverted = true let testJSON: JSONValue = ["foo": true] - let testData = Data(#"{"data": {"foo": true}}"#.utf8) try await subscribe(expecting: testJSON) await fulfillment(of: [onSubscribeInvoked], timeout: 0.05) - subscriptionEventHandler(.connection(.connecting), subscriptionItem) - subscriptionEventHandler(.connection(.connected), subscriptionItem) - subscriptionEventHandler(.data(testData), subscriptionItem) - subscriptionEventHandler(.connection(.disconnected), subscriptionItem) + try await MockAppSyncRealTimeClient.waitForSubscirbing() + try await MockAppSyncRealTimeClient.waitForSubscirbed() + mockAppSyncRealTimeClient?.triggerEvent(.data(.object([ + "data": .object([ + "foo": .boolean(true) + ]) + ]))) + mockAppSyncRealTimeClient?.triggerEvent(.unsubscribed) await waitForSubscriptionExpectations() } @@ -93,9 +107,9 @@ class GraphQLSubscribeCombineTests: OperationTestBase { try await subscribe() await fulfillment(of: [onSubscribeInvoked], timeout: 0.05) - subscriptionEventHandler(.connection(.connecting), subscriptionItem) - subscriptionEventHandler(.connection(.connected), subscriptionItem) - subscriptionEventHandler(.connection(.disconnected), subscriptionItem) + try await MockAppSyncRealTimeClient.waitForSubscirbing() + try await MockAppSyncRealTimeClient.waitForSubscirbed() + mockAppSyncRealTimeClient?.triggerEvent(.unsubscribed) await waitForSubscriptionExpectations() } @@ -110,31 +124,39 @@ class GraphQLSubscribeCombineTests: OperationTestBase { try await subscribe() await fulfillment(of: [onSubscribeInvoked], timeout: 0.05) - subscriptionEventHandler(.connection(.connecting), subscriptionItem) - subscriptionEventHandler(.failed("Error"), subscriptionItem) + try await MockAppSyncRealTimeClient.waitForSubscirbing() + mockAppSyncRealTimeClient?.triggerEvent(.error(["Error"])) await waitForSubscriptionExpectations() } func testDecodingError() async throws { - let testData = Data(#"{"data": {"foo": true}, "errors": []}"#.utf8) receivedCompletionFailure.isInverted = true receivedDataValueSuccess.isInverted = true try await subscribe() await fulfillment(of: [onSubscribeInvoked], timeout: 0.05) - subscriptionEventHandler(.connection(.connecting), subscriptionItem) - subscriptionEventHandler(.connection(.connected), subscriptionItem) - subscriptionEventHandler(.data(testData), subscriptionItem) - subscriptionEventHandler(.connection(.disconnected), subscriptionItem) + try await MockAppSyncRealTimeClient.waitForSubscirbing() + try await MockAppSyncRealTimeClient.waitForSubscirbed() + mockAppSyncRealTimeClient?.triggerEvent(.data(.object([ + "data": .object([ + "foo": .boolean(true) + ]), + "errors": .array([]) + ]))) + mockAppSyncRealTimeClient?.triggerEvent(.unsubscribed) await waitForSubscriptionExpectations() } func testMultipleSuccessValues() async throws { let testJSON: JSONValue = ["foo": true] - let testData = Data(#"{"data": {"foo": true}}"#.utf8) + let testData: JSONValue = .object([ + "data": .object([ + "foo": .boolean(true) + ]) + ]) receivedCompletionFailure.isInverted = true receivedDataValueError.isInverted = true receivedDataValueSuccess.expectedFulfillmentCount = 2 @@ -142,30 +164,39 @@ class GraphQLSubscribeCombineTests: OperationTestBase { try await subscribe(expecting: testJSON) await fulfillment(of: [onSubscribeInvoked], timeout: 0.05) - subscriptionEventHandler(.connection(.connecting), subscriptionItem) - subscriptionEventHandler(.connection(.connected), subscriptionItem) - subscriptionEventHandler(.data(testData), subscriptionItem) - subscriptionEventHandler(.data(testData), subscriptionItem) - subscriptionEventHandler(.connection(.disconnected), subscriptionItem) + try await MockAppSyncRealTimeClient.waitForSubscirbing() + try await MockAppSyncRealTimeClient.waitForSubscirbed() + mockAppSyncRealTimeClient?.triggerEvent(.data(testData)) + mockAppSyncRealTimeClient?.triggerEvent(.data(testData)) + mockAppSyncRealTimeClient?.triggerEvent(.unsubscribed) await waitForSubscriptionExpectations() } func testMixedSuccessAndErrorValues() async throws { - let successfulTestData = Data(#"{"data": {"foo": true}}"#.utf8) - let invalidTestData = Data(#"{"data": {"foo": true}, "errors": []}"#.utf8) + let successfulTestData: JSONValue = .object([ + "data": .object([ + "foo": .boolean(true) + ]) + ]) + let invalidTestData: JSONValue = .object([ + "data": .object([ + "foo": .boolean(true) + ]), + "errors": .array([]) + ]) receivedCompletionFailure.isInverted = true receivedDataValueSuccess.expectedFulfillmentCount = 2 try await subscribe() await fulfillment(of: [onSubscribeInvoked], timeout: 0.05) - subscriptionEventHandler(.connection(.connecting), subscriptionItem) - subscriptionEventHandler(.connection(.connected), subscriptionItem) - subscriptionEventHandler(.data(successfulTestData), subscriptionItem) - subscriptionEventHandler(.data(invalidTestData), subscriptionItem) - subscriptionEventHandler(.data(successfulTestData), subscriptionItem) - subscriptionEventHandler(.connection(.disconnected), subscriptionItem) + try await MockAppSyncRealTimeClient.waitForSubscirbing() + try await MockAppSyncRealTimeClient.waitForSubscirbed() + mockAppSyncRealTimeClient?.triggerEvent(.data(successfulTestData)) + mockAppSyncRealTimeClient?.triggerEvent(.data(invalidTestData)) + mockAppSyncRealTimeClient?.triggerEvent(.data(successfulTestData)) + mockAppSyncRealTimeClient?.triggerEvent(.unsubscribed) await waitForSubscriptionExpectations() } @@ -176,22 +207,13 @@ class GraphQLSubscribeCombineTests: OperationTestBase { /// self.subscriptionItem and self.subscriptionEventHandler, then fulfills /// self.onSubscribeInvoked func setUpMocksAndSubscriptionItems() throws { - let onSubscribe: MockSubscriptionConnection.OnSubscribe = { - requestString, variables, eventHandler in - let item = SubscriptionItem( - requestString: requestString, - variables: variables, - eventHandler: eventHandler - ) - - self.subscriptionItem = item - self.subscriptionEventHandler = eventHandler - self.onSubscribeInvoked.fulfill() - return item - } + defer { onSubscribeInvoked.fulfill() } + let mockAppSyncRealTimeClient = MockAppSyncRealTimeClient() + + self.mockAppSyncRealTimeClient = mockAppSyncRealTimeClient let onGetOrCreateConnection: MockSubscriptionConnectionFactory.OnGetOrCreateConnection = { _, _, _, _, _ in - MockSubscriptionConnection(onSubscribe: onSubscribe, onUnsubscribe: { _ in }) + return mockAppSyncRealTimeClient } try setUpPluginForSubscriptionResponse(onGetOrCreateConnection: onGetOrCreateConnection) diff --git a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Operation/GraphQLSubscribeTaskTests.swift b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Operation/GraphQLSubscribeTaskTests.swift index 7568eff886..d632e131c7 100644 --- a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Operation/GraphQLSubscribeTaskTests.swift +++ b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Operation/GraphQLSubscribeTaskTests.swift @@ -12,7 +12,6 @@ import Amplify @testable import AmplifyTestCommon @testable import AWSAPIPlugin @_implementationOnly import AmplifyAsyncTesting -import AppSyncRealTimeClient class GraphQLSubscribeTasksTests: OperationTestBase { @@ -30,14 +29,10 @@ class GraphQLSubscribeTasksTests: OperationTestBase { var receivedDataValueSuccess: XCTestExpectation! var receivedDataValueError: XCTestExpectation! - // Handles to the subscription item and event handler used to make mock calls into the - // subscription system - var subscriptionItem: SubscriptionItem! - var subscriptionEventHandler: SubscriptionEventHandler! - var connectionStateSink: AnyCancellable? var subscriptionDataSink: AnyCancellable? var expectedCompletionFailureError: APIError? + var mockAppSyncRealTimeClient: MockAppSyncRealTimeClient? override func setUp() async throws { try await super.setUp() @@ -56,6 +51,23 @@ class GraphQLSubscribeTasksTests: OperationTestBase { try setUpMocksAndSubscriptionItems() } + override func tearDown() async throws { + connectionStateSink?.cancel() + subscriptionDataSink?.cancel() + + onSubscribeInvoked = nil + receivedCompletionFailure = nil + receivedCompletionSuccess = nil + receivedStateValueConnected = nil + receivedStateValueConnecting = nil + receivedStateValueDisconnected = nil + + receivedDataValueError = nil + receivedDataValueSuccess = nil + mockAppSyncRealTimeClient = nil + try await super.tearDown() + } + func waitForSubscriptionExpectations() async { await fulfillment( of: [ @@ -76,15 +88,19 @@ class GraphQLSubscribeTasksTests: OperationTestBase { receivedDataValueError.isInverted = true let testJSON: JSONValue = ["foo": true] - let testData = Data(#"{"data": {"foo": true}}"#.utf8) + let testData: JSONValue = [ + "data": [ + "foo": true + ] + ] try await subscribe(expecting: testJSON) await fulfillment(of: [onSubscribeInvoked], timeout: 0.05) - subscriptionEventHandler(.connection(.connecting), subscriptionItem) - subscriptionEventHandler(.connection(.connected), subscriptionItem) - subscriptionEventHandler(.data(testData), subscriptionItem) - subscriptionEventHandler(.connection(.disconnected), subscriptionItem) + try await MockAppSyncRealTimeClient.waitForSubscirbing() + try await MockAppSyncRealTimeClient.waitForSubscirbed() + mockAppSyncRealTimeClient?.triggerEvent(.data(testData)) + mockAppSyncRealTimeClient?.triggerEvent(.unsubscribed) await waitForSubscriptionExpectations() } @@ -96,10 +112,9 @@ class GraphQLSubscribeTasksTests: OperationTestBase { try await subscribe() await fulfillment(of: [onSubscribeInvoked], timeout: 0.05) - - subscriptionEventHandler(.connection(.connecting), subscriptionItem) - subscriptionEventHandler(.connection(.connected), subscriptionItem) - subscriptionEventHandler(.connection(.disconnected), subscriptionItem) + try await MockAppSyncRealTimeClient.waitForSubscirbing() + try await MockAppSyncRealTimeClient.waitForSubscirbed() + mockAppSyncRealTimeClient?.triggerEvent(.unsubscribed) await waitForSubscriptionExpectations() } @@ -114,25 +129,9 @@ class GraphQLSubscribeTasksTests: OperationTestBase { try await subscribe() await fulfillment(of: [onSubscribeInvoked], timeout: 0.05) - subscriptionEventHandler(.connection(.connecting), subscriptionItem) - subscriptionEventHandler(.failed(ConnectionProviderError.limitExceeded(nil)), subscriptionItem) - expectedCompletionFailureError = APIError.operationError("", "", ConnectionProviderError.limitExceeded(nil)) - await waitForSubscriptionExpectations() - } - - func testConnectionErrorWithSubscriptionError() async throws { - receivedCompletionSuccess.isInverted = true - receivedStateValueConnected.isInverted = true - receivedStateValueDisconnected.isInverted = true - receivedDataValueSuccess.isInverted = true - receivedDataValueError.isInverted = true - - try await subscribe() - await fulfillment(of: [onSubscribeInvoked], timeout: 0.05) - - subscriptionEventHandler(.connection(.connecting), subscriptionItem) - subscriptionEventHandler(.failed(ConnectionProviderError.subscription("", nil)), subscriptionItem) - expectedCompletionFailureError = APIError.operationError("", "", ConnectionProviderError.subscription("", nil)) + try await MockAppSyncRealTimeClient.waitForSubscirbing() + mockAppSyncRealTimeClient?.triggerEvent(.error([AppSyncRealTimeRequest.Error.limitExceeded])) + expectedCompletionFailureError = APIError.operationError("", "", AppSyncRealTimeRequest.Error.limitExceeded) await waitForSubscriptionExpectations() } @@ -146,13 +145,18 @@ class GraphQLSubscribeTasksTests: OperationTestBase { try await subscribe() await fulfillment(of: [onSubscribeInvoked], timeout: 0.05) - subscriptionEventHandler(.connection(.connecting), subscriptionItem) - subscriptionEventHandler(.failed(ConnectionProviderError.unauthorized), subscriptionItem) - expectedCompletionFailureError = APIError.operationError("", "", ConnectionProviderError.unauthorized) + let unauthorizedError = GraphQLError(message: "", extensions: ["errorType": "Unauthorized"]) + try await MockAppSyncRealTimeClient.waitForSubscirbing() + mockAppSyncRealTimeClient?.triggerEvent(.error([unauthorizedError])) + expectedCompletionFailureError = APIError.operationError( + "Subscription item event failed with error: Unauthorized", + "", + GraphQLResponseError.error([unauthorizedError]) + ) await waitForSubscriptionExpectations() } - func testConnectionErrorWithConnectionProviderConnectionError() async throws { + func testConnectionErrorWithAppSyncConnectionError() async throws { receivedCompletionSuccess.isInverted = true receivedStateValueConnected.isInverted = true receivedStateValueDisconnected.isInverted = true @@ -162,31 +166,35 @@ class GraphQLSubscribeTasksTests: OperationTestBase { try await subscribe() await fulfillment(of: [onSubscribeInvoked], timeout: 0.05) - subscriptionEventHandler(.connection(.connecting), subscriptionItem) - subscriptionEventHandler(.failed(ConnectionProviderError.connection), subscriptionItem) - expectedCompletionFailureError = APIError.networkError("", nil, URLError(.networkConnectionLost)) + try await MockAppSyncRealTimeClient.waitForSubscirbing() + mockAppSyncRealTimeClient?.triggerEvent(.error([URLError(URLError.Code(rawValue: 400))])) + expectedCompletionFailureError = APIError.operationError("", "", URLError(URLError.Code(rawValue: 400))) await waitForSubscriptionExpectations() } func testDecodingError() async throws { - let testData = Data(#"{"data": {"foo": true}, "errors": []}"#.utf8) + let testData: JSONValue = [ + "data": [ "foo": true ], + "errors": [] + ] receivedCompletionFailure.isInverted = true receivedDataValueSuccess.isInverted = true try await subscribe() await fulfillment(of: [onSubscribeInvoked], timeout: 0.05) - - subscriptionEventHandler(.connection(.connecting), subscriptionItem) - subscriptionEventHandler(.connection(.connected), subscriptionItem) - subscriptionEventHandler(.data(testData), subscriptionItem) - subscriptionEventHandler(.connection(.disconnected), subscriptionItem) + try await MockAppSyncRealTimeClient.waitForSubscirbing() + try await MockAppSyncRealTimeClient.waitForSubscirbed() + mockAppSyncRealTimeClient?.triggerEvent(.data(testData)) + mockAppSyncRealTimeClient?.triggerEvent(.unsubscribed) await waitForSubscriptionExpectations() } func testMultipleSuccessValues() async throws { let testJSON: JSONValue = ["foo": true] - let testData = Data(#"{"data": {"foo": true}}"#.utf8) + let testData: JSONValue = [ + "data": [ "foo": true ] + ] receivedCompletionFailure.isInverted = true receivedDataValueError.isInverted = true @@ -195,18 +203,23 @@ class GraphQLSubscribeTasksTests: OperationTestBase { try await subscribe(expecting: testJSON) await fulfillment(of: [onSubscribeInvoked], timeout: 0.05) - subscriptionEventHandler(.connection(.connecting), subscriptionItem) - subscriptionEventHandler(.connection(.connected), subscriptionItem) - subscriptionEventHandler(.data(testData), subscriptionItem) - subscriptionEventHandler(.data(testData), subscriptionItem) - subscriptionEventHandler(.connection(.disconnected), subscriptionItem) + try await MockAppSyncRealTimeClient.waitForSubscirbing() + try await MockAppSyncRealTimeClient.waitForSubscirbed() + mockAppSyncRealTimeClient?.triggerEvent(.data(testData)) + mockAppSyncRealTimeClient?.triggerEvent(.data(testData)) + mockAppSyncRealTimeClient?.triggerEvent(.unsubscribed) await waitForSubscriptionExpectations() } func testMixedSuccessAndErrorValues() async throws { - let successfulTestData = Data(#"{"data": {"foo": true}}"#.utf8) - let invalidTestData = Data(#"{"data": {"foo": true}, "errors": []}"#.utf8) + let successfulTestData: JSONValue = [ + "data": [ "foo": true ] + ] + let invalidTestData: JSONValue = [ + "data": [ "foo": true ], + "errors": [] + ] receivedCompletionFailure.isInverted = true receivedDataValueSuccess.expectedFulfillmentCount = 2 @@ -214,12 +227,12 @@ class GraphQLSubscribeTasksTests: OperationTestBase { try await subscribe() await fulfillment(of: [onSubscribeInvoked], timeout: 0.05) - subscriptionEventHandler(.connection(.connecting), subscriptionItem) - subscriptionEventHandler(.connection(.connected), subscriptionItem) - subscriptionEventHandler(.data(successfulTestData), subscriptionItem) - subscriptionEventHandler(.data(invalidTestData), subscriptionItem) - subscriptionEventHandler(.data(successfulTestData), subscriptionItem) - subscriptionEventHandler(.connection(.disconnected), subscriptionItem) + try await MockAppSyncRealTimeClient.waitForSubscirbing() + try await MockAppSyncRealTimeClient.waitForSubscirbed() + mockAppSyncRealTimeClient?.triggerEvent(.data(successfulTestData)) + mockAppSyncRealTimeClient?.triggerEvent(.data(invalidTestData)) + mockAppSyncRealTimeClient?.triggerEvent(.data(successfulTestData)) + mockAppSyncRealTimeClient?.triggerEvent(.unsubscribed) await waitForSubscriptionExpectations() } @@ -230,25 +243,12 @@ class GraphQLSubscribeTasksTests: OperationTestBase { /// self.subscriptionItem and self.subscriptionEventHandler, then fulfills /// self.onSubscribeInvoked func setUpMocksAndSubscriptionItems() throws { - let onSubscribe: MockSubscriptionConnection.OnSubscribe = { - requestString, variables, eventHandler in - let item = SubscriptionItem( - requestString: requestString, - variables: variables, - eventHandler: eventHandler - ) - - self.subscriptionItem = item - self.subscriptionEventHandler = eventHandler - self.onSubscribeInvoked.fulfill() - return item + defer { self.onSubscribeInvoked.fulfill() } + let mockAppSyncRealTimeClient = MockAppSyncRealTimeClient() + self.mockAppSyncRealTimeClient = mockAppSyncRealTimeClient + try setUpPluginForSubscriptionResponse { _, _, _, _, _ in + mockAppSyncRealTimeClient } - - let onGetOrCreateConnection: MockSubscriptionConnectionFactory.OnGetOrCreateConnection = { _, _, _, _, _ in - MockSubscriptionConnection(onSubscribe: onSubscribe, onUnsubscribe: { _ in }) - } - - try setUpPluginForSubscriptionResponse(onGetOrCreateConnection: onGetOrCreateConnection) } /// Calls `Amplify.API.subscribe` with a request made from a generic document, and returns @@ -313,23 +313,17 @@ extension APIError: Equatable { (.pluginError, .pluginError): return true case (.operationError(_, _, let lhs), .operationError(_, _, let rhs)): - if let lhs = lhs as? ConnectionProviderError, let rhs = rhs as? ConnectionProviderError { - switch (lhs, rhs) { - case (.connection, .connection), - (.jsonParse, .jsonParse), - (.limitExceeded, .limitExceeded), - (.subscription, .subscription), - (.unauthorized, .unauthorized), - (.unknown, .unknown): - return true - default: - return false - } - } else if lhs == nil && rhs == nil { - return true - } else { - return false + switch (lhs, rhs) { + case let (lhs, rhs) as (URLError, URLError): + return lhs == rhs + case let (lhs, rhs) as (GraphQLResponseError, GraphQLResponseError): + return lhs.errorDescription == rhs.errorDescription + case let (lhs, rhs) as (AppSyncRealTimeRequest.Error, AppSyncRealTimeRequest.Error): + return lhs == rhs + case (.none, .none): return true + default: return false } + case (.networkError(_, _, let lhs), .networkError(_, _, let rhs)): if let lhs = lhs as? URLError, let rhs = rhs as? URLError { return lhs.code == rhs.code diff --git a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Operation/GraphQLSubscribeTests.swift b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Operation/GraphQLSubscribeTests.swift index 5624f14c0b..4b28c62cc2 100644 --- a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Operation/GraphQLSubscribeTests.swift +++ b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Operation/GraphQLSubscribeTests.swift @@ -9,7 +9,6 @@ import XCTest @testable import Amplify @testable import AmplifyTestCommon @testable import AWSAPIPlugin -import AppSyncRealTimeClient class GraphQLSubscribeTests: OperationTestBase { @@ -28,10 +27,7 @@ class GraphQLSubscribeTests: OperationTestBase { var receivedSubscriptionEventData: XCTestExpectation! var receivedSubscriptionEventError: XCTestExpectation! - // Handles to the subscription item and event handler used to make mock calls into the - // subscription system - var subscriptionItem: SubscriptionItem! - var subscriptionEventHandler: SubscriptionEventHandler! + var mockAppSyncRealTimeClient: MockAppSyncRealTimeClient! override func setUp() async throws { try await super.setUp() @@ -50,6 +46,30 @@ class GraphQLSubscribeTests: OperationTestBase { try setUpMocksAndSubscriptionItems() } + override func tearDown() async throws { + onSubscribeInvoked = nil + receivedCompletionFinish = nil + receivedCompletionFailure = nil + receivedConnected = nil + receivedDisconnected = nil + receivedSubscriptionEventData = nil + receivedSubscriptionEventError = nil + + mockAppSyncRealTimeClient = nil + try await super.tearDown() + } + + private func waitForExpectations(timeout: TimeInterval) async { + await fulfillment(of: [ + receivedCompletionFinish, + receivedCompletionFailure, + receivedConnected, + receivedDisconnected, + receivedSubscriptionEventData, + receivedSubscriptionEventError + ], timeout: timeout) + } + /// Lifecycle test /// /// When: @@ -62,9 +82,11 @@ class GraphQLSubscribeTests: OperationTestBase { /// - The value handler is invoked with a successfully decoded value /// - The value handler is invoked with a disconnection message /// - The completion handler is invoked with a normal termination - func testHappyPath() throws { + func testHappyPath() async throws { let testJSON: JSONValue = ["foo": true] - let testData = Data(#"{"data": {"foo": true}}"#.utf8) + let testData: JSONValue = [ + "data": [ "foo": true ] + ] receivedCompletionFinish.shouldTrigger = true receivedCompletionFailure.shouldTrigger = false receivedConnected.shouldTrigger = true @@ -73,14 +95,14 @@ class GraphQLSubscribeTests: OperationTestBase { receivedSubscriptionEventError.shouldTrigger = false subscribe(expecting: testJSON) - wait(for: [onSubscribeInvoked], timeout: 0.05) + await fulfillment(of: [onSubscribeInvoked], timeout: 0.05) - subscriptionEventHandler(.connection(.connecting), subscriptionItem) - subscriptionEventHandler(.connection(.connected), subscriptionItem) - subscriptionEventHandler(.data(testData), subscriptionItem) - subscriptionEventHandler(.connection(.disconnected), subscriptionItem) + try await MockAppSyncRealTimeClient.waitForSubscirbing() + try await MockAppSyncRealTimeClient.waitForSubscirbed() + mockAppSyncRealTimeClient.triggerEvent(.data(testData)) + mockAppSyncRealTimeClient.triggerEvent(.unsubscribed) - waitForExpectations(timeout: 0.05) + await waitForExpectations(timeout: 0.05) } /// Lifecycle test @@ -94,7 +116,7 @@ class GraphQLSubscribeTests: OperationTestBase { /// - The value handler is not invoked with with a data value /// - The value handler is invoked with a disconnection message /// - The completion handler is invoked with a normal termination - func testConnectionWithNoData() throws { + func testConnectionWithNoData() async throws { receivedCompletionFinish.shouldTrigger = true receivedCompletionFailure.shouldTrigger = false receivedConnected.shouldTrigger = true @@ -103,13 +125,13 @@ class GraphQLSubscribeTests: OperationTestBase { receivedSubscriptionEventError.shouldTrigger = false subscribe() - wait(for: [onSubscribeInvoked], timeout: 0.05) + await fulfillment(of: [onSubscribeInvoked], timeout: 0.05) - subscriptionEventHandler(.connection(.connecting), subscriptionItem) - subscriptionEventHandler(.connection(.connected), subscriptionItem) - subscriptionEventHandler(.connection(.disconnected), subscriptionItem) + try await MockAppSyncRealTimeClient.waitForSubscirbing() + try await MockAppSyncRealTimeClient.waitForSubscirbed() + mockAppSyncRealTimeClient.triggerEvent(.unsubscribed) - waitForExpectations(timeout: 0.05) + await waitForExpectations(timeout: 0.05) } /// Lifecycle test @@ -122,7 +144,7 @@ class GraphQLSubscribeTests: OperationTestBase { /// - The value handler is not invoked with with a data value /// - The value handler is invoked with a disconnection message /// - The completion handler is invoked with an error termination - func testConnectionError() throws { + func testConnectionError() async throws { receivedCompletionFinish.shouldTrigger = false receivedCompletionFailure.shouldTrigger = true receivedConnected.shouldTrigger = false @@ -131,12 +153,12 @@ class GraphQLSubscribeTests: OperationTestBase { receivedSubscriptionEventError.shouldTrigger = false subscribe() - wait(for: [onSubscribeInvoked], timeout: 0.05) + await fulfillment(of: [onSubscribeInvoked], timeout: 0.05) - subscriptionEventHandler(.connection(.connecting), subscriptionItem) - subscriptionEventHandler(.failed("Error"), subscriptionItem) + try await MockAppSyncRealTimeClient.waitForSubscirbing() + mockAppSyncRealTimeClient.triggerEvent(.error(["Error"])) - waitForExpectations(timeout: 0.05) + await waitForExpectations(timeout: 0.05) } /// Lifecycle test @@ -151,8 +173,11 @@ class GraphQLSubscribeTests: OperationTestBase { /// - The value handler is invoked with an error /// - The value handler is invoked with a disconnection message /// - The completion handler is invoked with a normal termination - func testDecodingError() throws { - let testData = Data(#"{"data": {"foo": true}, "errors": []}"#.utf8) + func testDecodingError() async throws { + let testData: JSONValue = [ + "data": [ "foo": true ], + "errors": [ ] + ] receivedCompletionFinish.shouldTrigger = true receivedCompletionFailure.shouldTrigger = false receivedConnected.shouldTrigger = true @@ -161,19 +186,23 @@ class GraphQLSubscribeTests: OperationTestBase { receivedSubscriptionEventError.shouldTrigger = true subscribe() - wait(for: [onSubscribeInvoked], timeout: 0.05) + await fulfillment(of: [onSubscribeInvoked], timeout: 0.05) - subscriptionEventHandler(.connection(.connecting), subscriptionItem) - subscriptionEventHandler(.connection(.connected), subscriptionItem) - subscriptionEventHandler(.data(testData), subscriptionItem) - subscriptionEventHandler(.connection(.disconnected), subscriptionItem) + try await MockAppSyncRealTimeClient.waitForSubscirbing() + try await MockAppSyncRealTimeClient.waitForSubscirbed() + mockAppSyncRealTimeClient.triggerEvent(.data(testData)) + mockAppSyncRealTimeClient.triggerEvent(.unsubscribed) - waitForExpectations(timeout: 0.05) + await waitForExpectations(timeout: 0.05) } - func testMultipleSuccessValues() throws { + func testMultipleSuccessValues() async throws { let testJSON: JSONValue = ["foo": true] - let testData = Data(#"{"data": {"foo": true}}"#.utf8) + let testData: JSONValue = [ + "data": [ + "foo": true + ] + ] receivedCompletionFinish.shouldTrigger = true receivedCompletionFailure.shouldTrigger = false receivedConnected.shouldTrigger = true @@ -183,20 +212,29 @@ class GraphQLSubscribeTests: OperationTestBase { receivedSubscriptionEventError.shouldTrigger = false subscribe(expecting: testJSON) - wait(for: [onSubscribeInvoked], timeout: 0.05) + await fulfillment(of: [onSubscribeInvoked], timeout: 0.05) - subscriptionEventHandler(.connection(.connecting), subscriptionItem) - subscriptionEventHandler(.connection(.connected), subscriptionItem) - subscriptionEventHandler(.data(testData), subscriptionItem) - subscriptionEventHandler(.data(testData), subscriptionItem) - subscriptionEventHandler(.connection(.disconnected), subscriptionItem) + try await MockAppSyncRealTimeClient.waitForSubscirbing() + try await MockAppSyncRealTimeClient.waitForSubscirbed() + mockAppSyncRealTimeClient.triggerEvent(.data(testData)) + mockAppSyncRealTimeClient.triggerEvent(.data(testData)) + mockAppSyncRealTimeClient.triggerEvent(.unsubscribed) - waitForExpectations(timeout: 0.05) + await waitForExpectations(timeout: 0.05) } - func testMixedSuccessAndErrorValues() throws { - let successfulTestData = Data(#"{"data": {"foo": true}}"#.utf8) - let invalidTestData = Data(#"{"data": {"foo": true}, "errors": []}"#.utf8) + func testMixedSuccessAndErrorValues() async throws { + let successfulTestData: JSONValue = [ + "data": [ + "foo": true + ] + ] + let invalidTestData: JSONValue = [ + "data": [ + "foo": true + ], + "errors": [] + ] receivedCompletionFinish.shouldTrigger = true receivedCompletionFailure.shouldTrigger = false receivedConnected.shouldTrigger = true @@ -206,16 +244,16 @@ class GraphQLSubscribeTests: OperationTestBase { receivedSubscriptionEventError.shouldTrigger = true subscribe() - wait(for: [onSubscribeInvoked], timeout: 0.05) + await fulfillment(of: [onSubscribeInvoked], timeout: 0.05) - subscriptionEventHandler(.connection(.connecting), subscriptionItem) - subscriptionEventHandler(.connection(.connected), subscriptionItem) - subscriptionEventHandler(.data(successfulTestData), subscriptionItem) - subscriptionEventHandler(.data(invalidTestData), subscriptionItem) - subscriptionEventHandler(.data(successfulTestData), subscriptionItem) - subscriptionEventHandler(.connection(.disconnected), subscriptionItem) + try await MockAppSyncRealTimeClient.waitForSubscirbing() + try await MockAppSyncRealTimeClient.waitForSubscirbed() + mockAppSyncRealTimeClient.triggerEvent(.data(successfulTestData)) + mockAppSyncRealTimeClient.triggerEvent(.data(invalidTestData)) + mockAppSyncRealTimeClient.triggerEvent(.data(successfulTestData)) + mockAppSyncRealTimeClient.triggerEvent(.unsubscribed) - waitForExpectations(timeout: 0.05) + await waitForExpectations(timeout: 0.05) } // MARK: - Utilities @@ -224,25 +262,13 @@ class GraphQLSubscribeTests: OperationTestBase { /// self.subscriptionItem and self.subscriptionEventHandler, then fulfills /// self.onSubscribeInvoked func setUpMocksAndSubscriptionItems() throws { - let onSubscribe: MockSubscriptionConnection.OnSubscribe = { - requestString, variables, eventHandler in - let item = SubscriptionItem( - requestString: requestString, - variables: variables, - eventHandler: eventHandler - ) - - self.subscriptionItem = item - self.subscriptionEventHandler = eventHandler - self.onSubscribeInvoked.fulfill() - return item - } + defer { onSubscribeInvoked.fulfill() } + let mockAppSyncRealTimeClient = MockAppSyncRealTimeClient() - let onGetOrCreateConnection: MockSubscriptionConnectionFactory.OnGetOrCreateConnection = { _, _, _, _, _ in - MockSubscriptionConnection(onSubscribe: onSubscribe, onUnsubscribe: { _ in }) + self.mockAppSyncRealTimeClient = mockAppSyncRealTimeClient + try setUpPluginForSubscriptionResponse { _, _, _, _, _ in + mockAppSyncRealTimeClient } - - try setUpPluginForSubscriptionResponse(onGetOrCreateConnection: onGetOrCreateConnection) } /// Calls `Amplify.API.subscribe` with a request made from a generic document, and returns diff --git a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Operation/OperationTestBase.swift b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Operation/OperationTestBase.swift index 6c5496d13e..26d9014091 100644 --- a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Operation/OperationTestBase.swift +++ b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Operation/OperationTestBase.swift @@ -24,7 +24,7 @@ class OperationTestBase: XCTestCase { func setUpPlugin( sessionFactory: URLSessionBehaviorFactory? = nil, - subscriptionConnectionFactory: SubscriptionConnectionFactory? = nil, + appSyncRealTimeClientFactory: AppSyncRealTimeClientFactoryProtocol? = nil, endpointType: AWSAPICategoryPluginEndpointType ) throws { apiPlugin = AWSAPIPlugin(sessionFactory: sessionFactory) @@ -42,7 +42,7 @@ class OperationTestBase: XCTestCase { configurationValues: configurationValues, apiAuthProviderFactory: APIAuthProviderFactory(), authService: MockAWSAuthService(), - subscriptionConnectionFactory: subscriptionConnectionFactory + appSyncRealTimeClientFactory: appSyncRealTimeClientFactory ) apiPlugin.configure(using: dependencies) @@ -68,12 +68,11 @@ class OperationTestBase: XCTestCase { func setUpPluginForSubscriptionResponse( onGetOrCreateConnection: @escaping MockSubscriptionConnectionFactory.OnGetOrCreateConnection ) throws { - let subscriptionConnectionFactory = MockSubscriptionConnectionFactory( - onGetOrCreateConnection: onGetOrCreateConnection - ) + + let appSyncRealTimeClientFactory = MockSubscriptionConnectionFactory(onGetOrCreateConnection: onGetOrCreateConnection) try setUpPlugin( - subscriptionConnectionFactory: subscriptionConnectionFactory, + appSyncRealTimeClientFactory: appSyncRealTimeClientFactory, endpointType: .graphQL ) } diff --git a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/SubscriptionFactory/AppSyncRealTimeClientFactoryTests.swift b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/SubscriptionFactory/AppSyncRealTimeClientFactoryTests.swift new file mode 100644 index 0000000000..7156ac7678 --- /dev/null +++ b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/SubscriptionFactory/AppSyncRealTimeClientFactoryTests.swift @@ -0,0 +1,37 @@ +// +// Copyright Amazon.com Inc. or its affiliates. +// All Rights Reserved. +// +// SPDX-License-Identifier: Apache-2.0 +// + + +import XCTest +@testable import AWSAPIPlugin + +class AppSyncRealTimeClientFactoryTests: XCTestCase { + + func testAppSyncRealTimeEndpoint_withAWSAppSyncDomain_returnCorrectRealtimeDomain() { + let appSyncEndpoint = URL(string: "https://abc.appsync-api.amazonaws.com/graphql")! + XCTAssertEqual( + AppSyncRealTimeClientFactory.appSyncRealTimeEndpoint(appSyncEndpoint), + URL(string: "https://abc.appsync-realtime-api.amazonaws.com/graphql") + ) + } + + func testAppSyncRealTimeEndpoint_withAWSAppSyncRealTimeDomain_returnTheSameDomain() { + let appSyncEndpoint = URL(string: "https://abc.appsync-realtime-api.amazonaws.com/graphql")! + XCTAssertEqual( + AppSyncRealTimeClientFactory.appSyncRealTimeEndpoint(appSyncEndpoint), + URL(string: "https://abc.appsync-realtime-api.amazonaws.com/graphql") + ) + } + + func testAppSyncRealTimeEndpoint_withCustomDomain_returnCorrectRealtimePath() { + let appSyncEndpoint = URL(string: "https://test.example.com/graphql")! + XCTAssertEqual( + AppSyncRealTimeClientFactory.appSyncRealTimeEndpoint(appSyncEndpoint), + URL(string: "https://test.example.com/graphql/realtime") + ) + } +} diff --git a/AmplifyPlugins/Auth/Sources/AWSCognitoAuthPlugin/Models/Options/AWSAuthConfirmSignUpOptions.swift b/AmplifyPlugins/Auth/Sources/AWSCognitoAuthPlugin/Models/Options/AWSAuthConfirmSignUpOptions.swift index 5d6505bfb6..c661c54855 100644 --- a/AmplifyPlugins/Auth/Sources/AWSCognitoAuthPlugin/Models/Options/AWSAuthConfirmSignUpOptions.swift +++ b/AmplifyPlugins/Auth/Sources/AWSCognitoAuthPlugin/Models/Options/AWSAuthConfirmSignUpOptions.swift @@ -11,7 +11,11 @@ public struct AWSAuthConfirmSignUpOptions { public let metadata: [String: String]? - public init(metadata: [String: String]? = nil) { + public let forceAliasCreation: Bool? + + public init(metadata: [String: String]? = nil, + forceAliasCreation: Bool? = nil) { self.metadata = metadata + self.forceAliasCreation = forceAliasCreation } } diff --git a/AmplifyPlugins/Auth/Sources/AWSCognitoAuthPlugin/Support/Utils/ConfirmSignUpInput+Amplify.swift b/AmplifyPlugins/Auth/Sources/AWSCognitoAuthPlugin/Support/Utils/ConfirmSignUpInput+Amplify.swift index 5590fccce0..12f057a7fb 100644 --- a/AmplifyPlugins/Auth/Sources/AWSCognitoAuthPlugin/Support/Utils/ConfirmSignUpInput+Amplify.swift +++ b/AmplifyPlugins/Auth/Sources/AWSCognitoAuthPlugin/Support/Utils/ConfirmSignUpInput+Amplify.swift @@ -13,6 +13,7 @@ extension ConfirmSignUpInput { confirmationCode: String, clientMetadata: [String: String]?, asfDeviceId: String?, + forceAliasCreation: Bool?, environment: UserPoolEnvironment ) { @@ -37,6 +38,7 @@ extension ConfirmSignUpInput { clientId: configuration.clientId, clientMetadata: clientMetadata, confirmationCode: confirmationCode, + forceAliasCreation: forceAliasCreation, secretHash: secretHash, userContextData: userContextData, username: username) diff --git a/AmplifyPlugins/Auth/Sources/AWSCognitoAuthPlugin/Task/AWSAuthConfirmSignUpTask.swift b/AmplifyPlugins/Auth/Sources/AWSCognitoAuthPlugin/Task/AWSAuthConfirmSignUpTask.swift index 6a03e246b5..c280dd400d 100644 --- a/AmplifyPlugins/Auth/Sources/AWSCognitoAuthPlugin/Task/AWSAuthConfirmSignUpTask.swift +++ b/AmplifyPlugins/Auth/Sources/AWSCognitoAuthPlugin/Task/AWSAuthConfirmSignUpTask.swift @@ -33,11 +33,13 @@ class AWSAuthConfirmSignUpTask: AuthConfirmSignUpTask, DefaultLogger { for: request.username, credentialStoreClient: authEnvironment.credentialsClient) let metadata = (request.options.pluginOptions as? AWSAuthConfirmSignUpOptions)?.metadata + let forceAliasCreation = (request.options.pluginOptions as? AWSAuthConfirmSignUpOptions)?.forceAliasCreation let client = try userPoolEnvironment.cognitoUserPoolFactory() let input = ConfirmSignUpInput(username: request.username, confirmationCode: request.code, clientMetadata: metadata, asfDeviceId: asfDeviceId, + forceAliasCreation: forceAliasCreation, environment: userPoolEnvironment) _ = try await client.confirmSignUp(input: input) log.verbose("Received success") diff --git a/AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/ResolverTests/SignUpState/ConfirmSignUpInputTests.swift b/AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/ResolverTests/SignUpState/ConfirmSignUpInputTests.swift index 52e54270d4..f926817dc6 100644 --- a/AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/ResolverTests/SignUpState/ConfirmSignUpInputTests.swift +++ b/AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/ResolverTests/SignUpState/ConfirmSignUpInputTests.swift @@ -31,6 +31,7 @@ class ConfirmSignUpInputTests: XCTestCase { confirmationCode: "123", clientMetadata: [:], asfDeviceId: "asdfDeviceId", + forceAliasCreation: nil, environment: environment) XCTAssertNotNil(confirmSignUpInput.secretHash) @@ -55,6 +56,7 @@ class ConfirmSignUpInputTests: XCTestCase { confirmationCode: "123", clientMetadata: [:], asfDeviceId: nil, + forceAliasCreation: nil, environment: environment) XCTAssertNil(confirmSignUpInput.secretHash) diff --git a/AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/TaskTests/ClientBehaviorTests/SignUp/AWSAuthConfirmSignUpAPITests.swift b/AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/TaskTests/ClientBehaviorTests/SignUp/AWSAuthConfirmSignUpAPITests.swift index fe5c8124c2..92812c342d 100644 --- a/AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/TaskTests/ClientBehaviorTests/SignUp/AWSAuthConfirmSignUpAPITests.swift +++ b/AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/TaskTests/ClientBehaviorTests/SignUp/AWSAuthConfirmSignUpAPITests.swift @@ -24,7 +24,9 @@ class AWSAuthConfirmSignUpAPITests: BasePluginTest { func testSuccessfulSignUp() async throws { self.mockIdentityProvider = MockIdentityProvider( - mockConfirmSignUpResponse: { _ in + mockConfirmSignUpResponse: { request in + XCTAssertNil(request.clientMetadata) + XCTAssertNil(request.forceAliasCreation) return .init() } ) @@ -47,11 +49,14 @@ class AWSAuthConfirmSignUpAPITests: BasePluginTest { mockConfirmSignUpResponse: { request in XCTAssertNotNil(request.clientMetadata) XCTAssertEqual(request.clientMetadata?["key"], "value") + XCTAssertEqual(request.forceAliasCreation, true) return .init() } ) - let pluginOptions = AWSAuthConfirmSignUpOptions(metadata: ["key": "value"]) + let pluginOptions = AWSAuthConfirmSignUpOptions( + metadata: ["key": "value"], + forceAliasCreation: true) let options = AuthConfirmSignUpRequest.Options(pluginOptions: pluginOptions) let result = try await self.plugin.confirmSignUp( for: "jeffb", diff --git a/AmplifyPlugins/Core/AWSPluginsCore/API/AppSyncErrorType.swift b/AmplifyPlugins/Core/AWSPluginsCore/API/AppSyncErrorType.swift index 11daad656f..edcf20ddef 100644 --- a/AmplifyPlugins/Core/AWSPluginsCore/API/AppSyncErrorType.swift +++ b/AmplifyPlugins/Core/AWSPluginsCore/API/AppSyncErrorType.swift @@ -36,7 +36,7 @@ public enum AppSyncErrorType: Equatable { self = .conditionalCheck case AppSyncErrorType.conflictUnhandledErrorString: self = .conflictUnhandled - case AppSyncErrorType.unauthorizedErrorString: + case _ where value.contains(AppSyncErrorType.unauthorizedErrorString): self = .unauthorized case AppSyncErrorType.operationDisabledErrorString: self = .operationDisabled diff --git a/AmplifyPlugins/Core/AWSPluginsCore/ServiceConfiguration/AmplifyAWSServiceConfiguration.swift b/AmplifyPlugins/Core/AWSPluginsCore/ServiceConfiguration/AmplifyAWSServiceConfiguration.swift index f73d295a6d..ca8f6cdbf9 100644 --- a/AmplifyPlugins/Core/AWSPluginsCore/ServiceConfiguration/AmplifyAWSServiceConfiguration.swift +++ b/AmplifyPlugins/Core/AWSPluginsCore/ServiceConfiguration/AmplifyAWSServiceConfiguration.swift @@ -15,7 +15,7 @@ import Amplify public class AmplifyAWSServiceConfiguration { /// - Tag: AmplifyAWSServiceConfiguration.amplifyVersion - public static let amplifyVersion = "2.27.2" + public static let amplifyVersion = "2.27.3" /// - Tag: AmplifyAWSServiceConfiguration.platformName public static let platformName = "amplify-swift" diff --git a/AmplifyPlugins/Core/AWSPluginsCore/WebSocket/AmplifyNetworkMonitor.swift b/AmplifyPlugins/Core/AWSPluginsCore/WebSocket/AmplifyNetworkMonitor.swift new file mode 100644 index 0000000000..23eb1ec4e2 --- /dev/null +++ b/AmplifyPlugins/Core/AWSPluginsCore/WebSocket/AmplifyNetworkMonitor.swift @@ -0,0 +1,51 @@ +// +// Copyright Amazon.com Inc. or its affiliates. +// All Rights Reserved. +// +// SPDX-License-Identifier: Apache-2.0 +// + + +import Network +import Combine + +@_spi(WebSocket) +public final class AmplifyNetworkMonitor { + + public enum State { + case none + case online + case offline + } + + private let monitor: NWPathMonitor + + private let subject = PassthroughSubject() + + public var publisher: AnyPublisher<(State, State), Never> { + subject.scan((.none, .none)) { previous, next in + (previous.1, next) + }.eraseToAnyPublisher() + } + + public init(on interface: NWInterface.InterfaceType? = nil) { + monitor = interface.map(NWPathMonitor.init(requiredInterfaceType:)) ?? NWPathMonitor() + monitor.pathUpdateHandler = { [weak self] path in + self?.subject.send(path.status == .satisfied ? .online : .offline) + } + + monitor.start(queue: DispatchQueue( + label: "com.amazonaws.amplify.ios.network.websocket.monitor", + qos: .userInitiated + )) + } + + public func updateState(_ nextState: State) { + subject.send(nextState) + } + + deinit { + subject.send(completion: .finished) + monitor.cancel() + } +} diff --git a/AmplifyPlugins/Core/AWSPluginsCore/WebSocket/RetryWithJitter.swift b/AmplifyPlugins/Core/AWSPluginsCore/WebSocket/RetryWithJitter.swift new file mode 100644 index 0000000000..9da51cb03f --- /dev/null +++ b/AmplifyPlugins/Core/AWSPluginsCore/WebSocket/RetryWithJitter.swift @@ -0,0 +1,72 @@ +// +// Copyright Amazon.com Inc. or its affiliates. +// All Rights Reserved. +// +// SPDX-License-Identifier: Apache-2.0 +// + + +import Foundation + +@_spi(WebSocket) +public actor RetryWithJitter { + public enum Error: Swift.Error { + case maxRetryExceeded([Swift.Error]) + } + let base: UInt + let max: UInt + var retryCount: UInt = 0 + + init(base: UInt = 25, max: UInt = 6400) { + self.base = base + self.max = max + } + + // using FullJitter backoff strategy + // ref: https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/ + // Returns: retry backoff time interval in millisecond + func next() -> UInt { + let expo = min(max, powerOf2(count: retryCount) * base) + retryCount += 1 + return UInt.random(in: 0..( + maxRetryCount: UInt = 8, + shouldRetryOnError: (Swift.Error) -> Bool = { _ in true }, + _ operation: @escaping () async throws -> Output + ) async throws -> Output { + let retryWithJitter = RetryWithJitter() + func recursive(retryCount: UInt, cause: [Swift.Error]) async -> Result { + if retryCount == maxRetryCount { + return .failure(RetryWithJitter.Error.maxRetryExceeded(cause)) + } + + let backoffInterval = retryCount == 0 ? 0 : await retryWithJitter.next() + do { + try await Task.sleep(nanoseconds: UInt64(backoffInterval) * 1_000_000) + return .success(try await operation()) + } catch { + print("[RetryWithJitter] operation failed with error \(error), retrying(\(retryCount))") + if shouldRetryOnError(error) { + return await recursive(retryCount: retryCount + 1, cause: cause + [error]) + } else { + return .failure(error) + } + } + } + return try await recursive(retryCount: 0, cause: []).get() + } +} + +fileprivate func powerOf2(count: UInt) -> UInt { + count == 0 + ? 1 + : 2 * powerOf2(count: count - 1) +} diff --git a/AmplifyPlugins/Core/AWSPluginsCore/WebSocket/WebSocketClient.swift b/AmplifyPlugins/Core/AWSPluginsCore/WebSocket/WebSocketClient.swift new file mode 100644 index 0000000000..d3a76af540 --- /dev/null +++ b/AmplifyPlugins/Core/AWSPluginsCore/WebSocket/WebSocketClient.swift @@ -0,0 +1,370 @@ +// +// Copyright Amazon.com Inc. or its affiliates. +// All Rights Reserved. +// +// SPDX-License-Identifier: Apache-2.0 +// + + +import Foundation +import Amplify +import Combine + +/** + WebSocketClient wraps URLSessionWebSocketTask and offers + an abstraction of the data stream in the form of WebSocketEvent. + */ +@_spi(WebSocket) +public final actor WebSocketClient: NSObject { + public enum Error: Swift.Error { + case connectionLost + case connectionCancelled + } + + /// WebSocket server endpoint + private let url: URL + /// WebSocket subprotocols + private let protocols: [String] + /// Interceptor for appending additional info before makeing the connection + private var interceptor: WebSocketInterceptor? + /// Internal wriable WebSocketEvent data stream + private let subject = PassthroughSubject() + + private let retryWithJitter = RetryWithJitter() + + /// Network monitor provide notification of device network status + private let networkMonitor: WebSocketNetworkMonitorProtocol + + /// Cancellables bind with client life cycle + private var cancelables = Set() + /// The underlying URLSessionWebSocketTask + private var connection: URLSessionWebSocketTask? { + willSet { + self.connection?.cancel(with: .goingAway, reason: nil) + } + } + + /// A flag indicating whether to automatically update the connection upon network status updates + private var autoConnectOnNetworkStatusChange: Bool + /// A flag indicating whether to automatically retry on connection failure + private var autoRetryOnConnectionFailure: Bool + /// Data stream for downstream subscribers to engage with + public var publisher: AnyPublisher { + self.subject.eraseToAnyPublisher() + } + + public var isConnected: Bool { + self.connection?.state == .running + } + + /** + Creates a WebSocketClient. + + - Parameters: + - url: WebSocket server endpoint + - protocols: WebSocket subprotocols, for header `Sec-WebSocket-Protocol` + - interceptor: An optional interceptor for additional info before establishing the connection + - networkMonitor: Provides network status notifications + */ + public init( + url: URL, + protocols: [String] = [], + interceptor: WebSocketInterceptor? = nil, + networkMonitor: WebSocketNetworkMonitorProtocol = AmplifyNetworkMonitor() + ) { + self.url = Self.useWebSocketProtocolScheme(url: url) + self.protocols = protocols + self.interceptor = interceptor + self.autoConnectOnNetworkStatusChange = false + self.autoRetryOnConnectionFailure = false + self.networkMonitor = networkMonitor + super.init() + /** + The network monitor and retries should have a longer lifespan compared to the connection itself. + This ensures that when the network goes offline or the connection drops, + the network monitor can initiate a reconnection once the network is back online. + */ + Task { await self.startNetworkMonitor() } + Task { await self.retryOnConnectionFailure() } + } + + deinit { + self.subject.send(completion: .finished) + self.autoConnectOnNetworkStatusChange = false + self.autoRetryOnConnectionFailure = false + cancelables = Set() + } + + /** + Connect to WebSocket server. + - Parameters: + - autoConnectOnNetworkStatusChange: + A flag indicating whether this connection should be automatically updated when the network status changes. + - autoRetryOnConnectionFailure: + A flag indicating whether this connection should attampt to retry upon failure. + */ + public func connect( + autoConnectOnNetworkStatusChange: Bool = false, + autoRetryOnConnectionFailure: Bool = false + ) async { + guard self.connection?.state != .running else { + log.debug("[WebSocketClient] WebSocket is already in connecting state") + return + } + + log.debug("[WebSocketClient] WebSocket about to connect") + self.autoConnectOnNetworkStatusChange = autoConnectOnNetworkStatusChange + self.autoRetryOnConnectionFailure = autoRetryOnConnectionFailure + + await self.createConnectionAndRead() + } + + /** + Disconnect from WebSocket server. + + This will halt all automatic processes and attempt to gracefully close the connection. + */ + public func disconnect() { + guard self.connection?.state == .running else { + log.debug("[WebSocketClient] client should be in connected state to trigger disconnect") + return + } + + self.autoConnectOnNetworkStatusChange = false + self.autoRetryOnConnectionFailure = false + self.connection?.cancel(with: .goingAway, reason: nil) + } + + /** + Write text data to WebSocket server. + - Parameters: + - message: text message in String + */ + public func write(message: String) async throws { + log.debug("[WebSocketClient] WebSocket write message string: \(message)") + try await self.connection?.send(.string(message)) + } + + /** + Write binary data to WebSocket server. + - Parameters: + - message: binary message in Data + */ + public func write(message: Data) async throws { + log.debug("[WebSocketClient] WebSocket write message data: \(message)") + try await self.connection?.send(.data(message)) + } + + private func createWebSocketConnection() async -> URLSessionWebSocketTask { + let urlSession = URLSession(configuration: .default, delegate: self, delegateQueue: nil) + let decoratedURL = (await self.interceptor?.interceptConnection(url: self.url)) ?? self.url + return urlSession.webSocketTask(with: decoratedURL, protocols: self.protocols) + } + + private func createConnectionAndRead() async { + log.debug("[WebSocketClient] Creating new connection and starting read") + self.connection = await createWebSocketConnection() + + // Perform reading from a WebSocket in a separate task recursively to avoid blocking the execution. + Task { await self.startReadMessage() } + + self.connection?.resume() + } + + /** + Recusively read WebSocket data frames and publish to data stream. + */ + private func startReadMessage() async { + guard let connection = self.connection else { + log.debug("[WebSocketClient] WebSocket connection doesn't exist") + return + } + + if connection.state == .canceling || connection.state == .completed { + log.debug("[WebSocketClient] WebSocket connection state is \(connection.state). Failed to read websocket message") + return + } + + do { + let message = try await connection.receive() + log.debug("[WebSocketClient] WebSocket received message: \(String(describing: message))") + switch message { + case .data(let data): + subject.send(.data(data)) + case .string(let string): + subject.send(.string(string)) + @unknown default: + break + } + } catch { + if connection.state == .running { + subject.send(.error(error)) + } else { + log.debug("[WebSocketClient] read message failed with connection state \(connection.state), error \(error)") + } + } + + await self.startReadMessage() + } +} + +// MARK: - URLSession delegate +extension WebSocketClient: URLSessionWebSocketDelegate { + nonisolated public func urlSession( + _ session: URLSession, + webSocketTask: URLSessionWebSocketTask, + didOpenWithProtocol protocol: String? + ) { + log.debug("[WebSocketClient] Websocket connected") + self.subject.send(.connected) + } + + nonisolated public func urlSession( + _ session: URLSession, + webSocketTask: URLSessionWebSocketTask, + didCloseWith closeCode: URLSessionWebSocketTask.CloseCode, + reason: Data? + ) { + log.debug("[WebSocketClient] Websocket disconnected") + self.subject.send(.disconnected(closeCode, reason.flatMap { String(data: $0, encoding: .utf8) })) + } + + nonisolated public func urlSession( + _ session: URLSession, + task: URLSessionTask, + didCompleteWithError error: Swift.Error? + ) { + guard let error else { + log.debug("[WebSocketClient] URLSession didComplete") + return + } + + log.debug("[WebSocketClient] URLSession didCompleteWithError: \(error))") + + let nsError = error as NSError + switch (nsError.domain, nsError.code) { + case (NSURLErrorDomain.self, NSURLErrorNetworkConnectionLost), // connection lost + (NSPOSIXErrorDomain.self, Int(ECONNABORTED)): // background to foreground + self.subject.send(.error(WebSocketClient.Error.connectionLost)) + Task { [weak self] in + await self?.networkMonitor.updateState(.offline) + } + case (NSURLErrorDomain.self, NSURLErrorCancelled): + log.debug("Skipping NSURLErrorCancelled error") + self.subject.send(.error(WebSocketClient.Error.connectionCancelled)) + default: + self.subject.send(.error(error)) + } + } +} + +// MARK: - network reachability +extension WebSocketClient { + /// Monitor network status. Disconnect or reconnect when the network drops or comes back online. + private func startNetworkMonitor() { + networkMonitor.publisher.sink(receiveValue: { stateChange in + Task { [weak self] in + await self?.onNetworkStateChange(stateChange) + } + }) + .store(in: &cancelables) + } + + private func onNetworkStateChange( + _ stateChange: (AmplifyNetworkMonitor.State, AmplifyNetworkMonitor.State) + ) async { + guard self.autoConnectOnNetworkStatusChange == true else { + return + } + + switch stateChange { + case (.online, .offline): + log.debug("[WebSocketClient] NetworkMonitor - Device went offline") + self.connection?.cancel(with: .invalid, reason: nil) + self.subject.send(.disconnected(.invalid, nil)) + case (.offline, .online): + log.debug("[WebSocketClient] NetworkMonitor - Device back online") + await self.createConnectionAndRead() + default: + break + } + } +} + +// MARK: - auto retry on connection failure +extension WebSocketClient { + private func retryOnConnectionFailure() { + subject.map { event -> URLSessionWebSocketTask.CloseCode? in + guard case .disconnected(let closeCode, _) = event else { + return nil + } + return closeCode + } + .compactMap { $0 } + .sink(receiveCompletion: { _ in }) { closeCode in + Task { [weak self] in await self?.retryOnCloseCode(closeCode) } + } + .store(in: &cancelables) + + self.resetRetryCountOnConnected() + } + + private func resetRetryCountOnConnected() { + subject.filter { + if case .connected = $0 { + return true + } + return false + } + .sink(receiveCompletion: { _ in }) { _ in + Task { [weak self] in + await self?.retryWithJitter.reset() + } + } + .store(in: &cancelables) + } + + private func retryOnCloseCode(_ closeCode: URLSessionWebSocketTask.CloseCode) async { + guard self.autoRetryOnConnectionFailure == true else { + return + } + + switch closeCode { + case .internalServerError: + let delayInMs = await retryWithJitter.next() + Task { [weak self] in + try await Task.sleep(nanoseconds: UInt64(delayInMs) * 1_000_000) + await self?.createConnectionAndRead() + } + default: break + } + + } +} + +extension WebSocketClient { + static func useWebSocketProtocolScheme(url: URL) -> URL { + guard var urlComponents = URLComponents(url: url, resolvingAgainstBaseURL: false) else { + return url + } + urlComponents.scheme = urlComponents.scheme == "http" ? "ws" : "wss" + return urlComponents.url ?? url + } +} + +extension WebSocketClient: DefaultLogger { + public static var log: Logger { + Amplify.Logging.logger(forNamespace: String(describing: self)) + } + + public nonisolated var log: Logger { Self.log } +} + +extension WebSocketClient: Resettable { + public func reset() async { + self.subject.send(completion: .finished) + self.autoConnectOnNetworkStatusChange = false + self.autoRetryOnConnectionFailure = false + cancelables = Set() + } +} diff --git a/AmplifyPlugins/Core/AWSPluginsCore/WebSocket/WebSocketEvent.swift b/AmplifyPlugins/Core/AWSPluginsCore/WebSocket/WebSocketEvent.swift new file mode 100644 index 0000000000..35c101dd6e --- /dev/null +++ b/AmplifyPlugins/Core/AWSPluginsCore/WebSocket/WebSocketEvent.swift @@ -0,0 +1,18 @@ +// +// Copyright Amazon.com Inc. or its affiliates. +// All Rights Reserved. +// +// SPDX-License-Identifier: Apache-2.0 +// + + +import Foundation + +@_spi(WebSocket) +public enum WebSocketEvent { + case connected + case disconnected(URLSessionWebSocketTask.CloseCode, String?) + case data(Data) + case string(String) + case error(Error) +} diff --git a/AmplifyPlugins/Core/AWSPluginsCore/WebSocket/WebSocketInterceptor.swift b/AmplifyPlugins/Core/AWSPluginsCore/WebSocket/WebSocketInterceptor.swift new file mode 100644 index 0000000000..a53ec3b950 --- /dev/null +++ b/AmplifyPlugins/Core/AWSPluginsCore/WebSocket/WebSocketInterceptor.swift @@ -0,0 +1,14 @@ +// +// Copyright Amazon.com Inc. or its affiliates. +// All Rights Reserved. +// +// SPDX-License-Identifier: Apache-2.0 +// + + +import Foundation + +@_spi(WebSocket) +public protocol WebSocketInterceptor { + func interceptConnection(url: URL) async -> URL +} diff --git a/AmplifyPlugins/Core/AWSPluginsCore/WebSocket/WebSocketNetworkMonitorProtocol.swift b/AmplifyPlugins/Core/AWSPluginsCore/WebSocket/WebSocketNetworkMonitorProtocol.swift new file mode 100644 index 0000000000..3966e7ab9d --- /dev/null +++ b/AmplifyPlugins/Core/AWSPluginsCore/WebSocket/WebSocketNetworkMonitorProtocol.swift @@ -0,0 +1,18 @@ +// +// Copyright Amazon.com Inc. or its affiliates. +// All Rights Reserved. +// +// SPDX-License-Identifier: Apache-2.0 +// + + +import Foundation +import Combine + +@_spi(WebSocket) +public protocol WebSocketNetworkMonitorProtocol { + var publisher: AnyPublisher<(AmplifyNetworkMonitor.State, AmplifyNetworkMonitor.State), Never> { get } + func updateState(_ nextState: AmplifyNetworkMonitor.State) async +} + +extension AmplifyNetworkMonitor: WebSocketNetworkMonitorProtocol { } diff --git a/AmplifyPlugins/Core/AWSPluginsCoreTests/WebSocket/LocalWebSocketServer.swift b/AmplifyPlugins/Core/AWSPluginsCoreTests/WebSocket/LocalWebSocketServer.swift new file mode 100644 index 0000000000..1dc0fbd948 --- /dev/null +++ b/AmplifyPlugins/Core/AWSPluginsCoreTests/WebSocket/LocalWebSocketServer.swift @@ -0,0 +1,105 @@ +// +// Copyright Amazon.com Inc. or its affiliates. +// All Rights Reserved. +// +// SPDX-License-Identifier: Apache-2.0 +// + + +import Foundation +import Network + +class LocalWebSocketServer { + let portNumber = UInt16.random(in: 49152..<65535) + var connections = [NWConnection]() + + var listener: NWListener? + + private static func recursiveRead(_ connection: NWConnection) { + connection.receiveMessage { content, contentContext, _, error in + if let error { + print("Connection failed to receive message, error: \(error)") + return + } + + if let content, let contentContext { + connection.send(content: content, contentContext: contentContext, completion: .idempotent) + } + + recursiveRead(connection) + } + } + + func start() throws -> URL { + let params = NWParameters.tcp + let stack = params.defaultProtocolStack + let ws = NWProtocolWebSocket.Options(.version13) + stack.applicationProtocols.insert(ws, at: 0) + let port = NWEndpoint.Port(rawValue: portNumber)! + guard let listener = try? NWListener(using: params, on: port) else { + throw "unable to start the listener at: localhost:\(port)" + } + + listener.newConnectionHandler = { [weak self] conn in + self?.connections.append(conn) + conn.stateUpdateHandler = { state in + switch state { + case .ready: + print("Connection is ready") + case .setup: + print("Connection is setup") + case .preparing: + print("Connection is preparing") + case .waiting(let error): + print("Connection is waiting with error: \(error)") + case .failed(let error): + print("Connection failed with error \(error)") + case .cancelled: + print("Connection is cancelled") + @unknown default: + print("Connection is in unknown state -> \(state)") + } + } + conn.start(queue: DispatchQueue.global(qos: .userInitiated)) + Self.recursiveRead(conn) + } + + listener.stateUpdateHandler = { state in + switch state { + case .ready: + print("Socket is ready") + case .setup: + print("Socket is setup") + case .cancelled: + print("Socket is cancelled") + case .failed(let error): + print("Socket failed with error: \(error)") + case .waiting(let error): + print("Socket in waiting state with error: \(error)") + @unknown default: + print("Socket in unkown state -> \(state)") + break + } + } + + listener.start(queue: DispatchQueue.global(qos: .userInitiated)) + self.listener = listener + return URL(string: "http://localhost:\(portNumber)")! + } + + func stop() { + self.listener?.cancel() + } + + func sendTransientFailureToConnections() { + self.connections.forEach { + var metadata = NWProtocolWebSocket.Metadata(opcode: .close) + metadata.closeCode = .protocolCode(NWProtocolWebSocket.CloseCode.Defined.internalServerError) + $0.send( + content: nil, + contentContext: NWConnection.ContentContext(identifier: "WebSocket", metadata: [metadata]), + completion: .idempotent + ) + } + } +} diff --git a/AmplifyPlugins/Core/AWSPluginsCoreTests/WebSocket/RetryWithJitterTests.swift b/AmplifyPlugins/Core/AWSPluginsCoreTests/WebSocket/RetryWithJitterTests.swift new file mode 100644 index 0000000000..9ada954056 --- /dev/null +++ b/AmplifyPlugins/Core/AWSPluginsCoreTests/WebSocket/RetryWithJitterTests.swift @@ -0,0 +1,75 @@ +// +// Copyright Amazon.com Inc. or its affiliates. +// All Rights Reserved. +// +// SPDX-License-Identifier: Apache-2.0 +// + + +import XCTest +@testable @_spi(WebSocket) import AWSPluginsCore + +class RetryWithJitterTests: XCTestCase { + struct TestError: Error { + let message: String + } + + func testNext_returnDistinctValues() async { + let retryWithJitter = RetryWithJitter() + var values = Set() + for _ in 0..<20 { + values.insert(await retryWithJitter.next()) + } + XCTAssert(values.count > 10) + } + + func testNext_doNotBreachMaxCap() async { + let max: UInt = 100_000 + let retryWithJitter = RetryWithJitter(max: max) + var values = Set() + for _ in 0..<50 { + values.insert(await retryWithJitter.next()) + } + XCTAssert(values.allSatisfy { $0 < max}) + } + + func testExecute_operationFailed_retryToMaxRetryCount() async { + let maxRetryCount = 3 + let retryAttempts = expectation(description: "Total retry attempts") + retryAttempts.expectedFulfillmentCount = maxRetryCount + let failedWithExceedMaxRetryCountError = + expectation(description: "Execute should be failed with exceedMaxRetryCount error") + do { + try await RetryWithJitter.execute(maxRetryCount: UInt(maxRetryCount)) { + retryAttempts.fulfill() + throw TestError(message: "Failed operation") + } + } catch { + XCTAssert(error is RetryWithJitter.Error) + if case .maxRetryExceeded(let errors) = (error as! RetryWithJitter.Error) { + XCTAssertEqual(errors.count, maxRetryCount) + XCTAssert(errors.reduce(true) { + $0 && (($1 as? TestError).map { $0.message.contains("Failed operation") } == true) + } ) + failedWithExceedMaxRetryCountError.fulfill() + } + } + await fulfillment(of: [retryAttempts, failedWithExceedMaxRetryCountError], timeout: 5) + } + + func testExecute_operationSucceeded_noRetryObserved() async { + let maxRetryCount = 3 + let retryAttempts = expectation(description: "Total retry attempts") + retryAttempts.isInverted = true + let succeedExpectation = + expectation(description: "Execute should be succeed") + do { + try await RetryWithJitter.execute(maxRetryCount: UInt(maxRetryCount)) { + succeedExpectation.fulfill() + } + } catch { + XCTFail("No error expected") + } + await fulfillment(of: [retryAttempts, succeedExpectation], timeout: 1) + } +} diff --git a/AmplifyPlugins/Core/AWSPluginsCoreTests/WebSocket/WebSocketClientTests.swift b/AmplifyPlugins/Core/AWSPluginsCoreTests/WebSocket/WebSocketClientTests.swift new file mode 100644 index 0000000000..f3e53669c1 --- /dev/null +++ b/AmplifyPlugins/Core/AWSPluginsCoreTests/WebSocket/WebSocketClientTests.swift @@ -0,0 +1,213 @@ +// +// Copyright Amazon.com Inc. or its affiliates. +// All Rights Reserved. +// +// SPDX-License-Identifier: Apache-2.0 +// + + +import XCTest +import Combine +@testable @_spi(WebSocket) import AWSPluginsCore + +fileprivate let timeout: TimeInterval = 5 + +class WebSocketClientTests: XCTestCase { + var localWebSocketServer: LocalWebSocketServer? + + override func setUp() async throws { + localWebSocketServer = LocalWebSocketServer() + } + + override func tearDown() async throws { + localWebSocketServer?.stop() + } + + func testConnect_withHttpScheme_didConnectedWithWs() async throws { + guard let endpoint = try localWebSocketServer?.start() else { + XCTFail("Local WebSocket server failed to start") + return + } + let webSocketClient = WebSocketClient(url: endpoint) + await verifyConnected(webSocketClient) + } + + func testDisconnect_didDisconnectFromRemote() async throws { + var cancellables = Set() + guard let endpoint = try localWebSocketServer?.start() else { + XCTFail("Local WebSocket server failed to start") + return + } + + let disconnectedExpectation = expectation(description: "WebSocket did disconnect") + + let webSocketClient = WebSocketClient(url: endpoint) + await verifyConnected(webSocketClient) + + await webSocketClient.publisher + .sink { event in + switch event { + case let .disconnected(closeCode, reason): + XCTAssertNil(reason) + XCTAssertEqual(closeCode, .goingAway) + disconnectedExpectation.fulfill() + default: + XCTFail("No other type of event should be received") + } + } + .store(in: &cancellables) + await webSocketClient.disconnect() + await fulfillment(of: [disconnectedExpectation], timeout: timeout) + } + + func testWriteAndRead_withWebSocketClient_didBehavesCorrectly() async throws { + var cancellables = Set() + guard let endpoint = try localWebSocketServer?.start() else { + XCTFail("Local WebSocket server failed to start") + return + } + + let messageReceivedExpectation = expectation(description: "WebSocket could read/write text message") + let dataReceivedExpectation = expectation(description: "WebSocket could read/wirte binary message") + let sampleMessage = UUID().uuidString + let sampleDataMessage = UUID().uuidString + + let webSocketClient = WebSocketClient(url: endpoint) + await verifyConnected(webSocketClient) + await webSocketClient.publisher.sink { event in + switch event { + case .string(let message) where message == sampleMessage: + messageReceivedExpectation.fulfill() + case .data(let data): + XCTAssertEqual(sampleDataMessage.hexaData, data) + dataReceivedExpectation.fulfill() + default: + XCTFail("No other type of event should be received") + } + }.store(in: &cancellables) + + try await webSocketClient.write(message: sampleMessage) + try await webSocketClient.write(message: sampleDataMessage.hexaData) + await fulfillment(of: [ + messageReceivedExpectation, + dataReceivedExpectation + ], timeout: timeout, enforceOrder: true) + } + + func testWebSocketClient_whenNetworkStateChagnes_disconnectOrReconnect() async throws { + var cancellables = Set() + guard let endpoint = try localWebSocketServer?.start() else { + XCTFail("Local WebSocket server failed to start") + return + } + + let mockNetworkMonitor = MockNetworkMonitor() + let webSocketClient = WebSocketClient(url: endpoint, networkMonitor: mockNetworkMonitor) + await verifyConnected(webSocketClient, autoConnectOnNetworkStatusChange: true) + + let disconnectExpectation = expectation(description: "Network drop should trigger disconnect") + await webSocketClient.publisher.sink { event in + switch event { + case let .disconnected(closeCode, reason): + XCTAssertEqual(closeCode, .invalid) + XCTAssertNil(reason) + disconnectExpectation.fulfill() + case let .error(error): + XCTAssertEqual(error as? WebSocketClient.Error, WebSocketClient.Error.connectionCancelled) + default: + XCTFail("No other type of event should be received") + } + } + .store(in: &cancellables) + // set network offline + await mockNetworkMonitor.updateState(.offline) + await fulfillment(of: [disconnectExpectation], timeout: timeout) + cancellables = Set() + + try await Task.sleep(seconds: 0.1) + let reconnectExpectation = expectation(description: "Network back online trigger reconnect") + await webSocketClient.publisher.sink { event in + switch event { + case .connected: + reconnectExpectation.fulfill() + default: + XCTFail("No other type of event should be received") + } + } + .store(in: &cancellables) + // set network online again + await mockNetworkMonitor.updateState(.online) + await fulfillment(of: [reconnectExpectation], timeout: timeout) + } + + func testAutoRetry_whenReceiveTransientFailureFromServer() async throws { + var cancellables = Set() + guard let endpoint = try localWebSocketServer?.start() else { + XCTFail("Local WebSocket server failed to start") + return + } + + let webSocketClient = WebSocketClient(url: endpoint) + await verifyConnected(webSocketClient, autoRetryOnConnectionFailure: true) + + let disconnectExpectation = expectation(description: "Tresient Server Error should trigger retry") + let reconnectedExpectation = expectation(description: "Connected should be re-triggered") + + await webSocketClient.publisher.sink { event in + switch event { + case let .disconnected(closeCode, reason): + XCTAssertEqual(closeCode, .internalServerError) + XCTAssert(reason == nil || reason!.isEmpty) + disconnectExpectation.fulfill() + case .connected: + reconnectedExpectation.fulfill() + default: + XCTFail("No other type of event should be received") + } + } + .store(in: &cancellables) + localWebSocketServer?.sendTransientFailureToConnections() + await fulfillment(of: [disconnectExpectation, reconnectedExpectation], timeout: timeout, enforceOrder: true) + } + + private func verifyConnected( + _ webSocketClient: WebSocketClient, + autoConnectOnNetworkStatusChange: Bool = false, + autoRetryOnConnectionFailure: Bool = false + ) async { + var cancellables = Set() + let connectedExpectation = expectation(description: "WebSocket did connect") + await webSocketClient.publisher.sink { event in + switch event { + case .connected: + connectedExpectation.fulfill() + default: + XCTFail("No other type of event should be received") + } + }.store(in: &cancellables) + + await webSocketClient.connect( + autoConnectOnNetworkStatusChange: autoConnectOnNetworkStatusChange, + autoRetryOnConnectionFailure: autoRetryOnConnectionFailure + ) + await fulfillment(of: [connectedExpectation], timeout: timeout) + } + +} + + +fileprivate class MockNetworkMonitor: WebSocketNetworkMonitorProtocol { + typealias State = AmplifyNetworkMonitor.State + let subject = PassthroughSubject() + var publisher: AnyPublisher<(State, State), Never> { + subject.scan((State.online, State.online)) { partial, newValue in + (partial.1, newValue) + }.eraseToAnyPublisher() + } + + func updateState(_ nextState: AmplifyNetworkMonitor.State) async { + subject.send(nextState) + } + + +} diff --git a/AmplifyPlugins/DataStore/Tests/AWSDataStorePluginTests/Sync/SubscriptionSync/IncomingAsyncSubscriptionEventPublisherTests.swift b/AmplifyPlugins/DataStore/Tests/AWSDataStorePluginTests/Sync/SubscriptionSync/IncomingAsyncSubscriptionEventPublisherTests.swift index 48100ba687..3adb5410c2 100644 --- a/AmplifyPlugins/DataStore/Tests/AWSDataStorePluginTests/Sync/SubscriptionSync/IncomingAsyncSubscriptionEventPublisherTests.swift +++ b/AmplifyPlugins/DataStore/Tests/AWSDataStorePluginTests/Sync/SubscriptionSync/IncomingAsyncSubscriptionEventPublisherTests.swift @@ -62,6 +62,7 @@ final class IncomingAsyncSubscriptionEventPublisherTests: XCTestCase { /// Ensure that the publisher-subscriber with back pressure is receiving all the events in the order in which they were sent. func testSubscriberRecievedEventsInOrder() async throws { + let prefix = UUID().uuidString let expectedEvents = expectation(description: "Expected number of ") let expectedOrder = AtomicValue<[String]>(initialValue: []) let actualOrder = AtomicValue<[String]>(initialValue: []) @@ -92,7 +93,7 @@ final class IncomingAsyncSubscriptionEventPublisherTests: XCTestCase { ) for index in 0..() + let expectation = expectation(description: "DataStore with 19 models should establish subscription in 2 seconds") + Amplify.Hub.publisher(for: .dataStore) + .filter { $0.eventName == HubPayload.EventName.DataStore.subscriptionsEstablished } + .sink { _ in expectation.fulfill() } + .store(in: &cancellables) + + Task { + try await Amplify.DataStore.start() + } + await fulfillment(of: [expectation], timeout: timeout) + withExtendedLifetime(cancellables, { }) + } + + private func stopDataStoreAndVerifyAppSyncClientDisconnected() async throws { + try await Amplify.DataStore.stop() + + guard let awsApiPlugin = try? Amplify.API.getPlugin(for: "awsAPIPlugin") as? AWSAPIPlugin else { + XCTFail("AWSAPIPlugin should not be nil") + return + } + + guard let appSyncRealTimeClientFactory = + awsApiPlugin.appSyncRealTimeClientFactory as? AppSyncRealTimeClientFactory + else { + XCTFail("AppSyncRealTimeClientFactory should not be nil") + return + } + + let appSyncRealTimeClients = (await appSyncRealTimeClientFactory.apiToClientCache.values) + .map { $0 as! AppSyncRealTimeClient } + + try await Task.sleep(seconds: 1) + + var allClientsDisconnected = true + for client in appSyncRealTimeClients { + let clientIsConnected = await client.isConnected + allClientsDisconnected = allClientsDisconnected && !clientIsConnected + } + XCTAssertTrue(allClientsDisconnected) + } +} diff --git a/AmplifyPlugins/DataStore/Tests/DataStoreHostApp/AWSDataStorePluginIntegrationTests/DataStoreObserveQueryTests.swift b/AmplifyPlugins/DataStore/Tests/DataStoreHostApp/AWSDataStorePluginIntegrationTests/DataStoreObserveQueryTests.swift index 7b40d7a51f..4e13e3e2ff 100644 --- a/AmplifyPlugins/DataStore/Tests/DataStoreHostApp/AWSDataStorePluginIntegrationTests/DataStoreObserveQueryTests.swift +++ b/AmplifyPlugins/DataStore/Tests/DataStoreHostApp/AWSDataStorePluginIntegrationTests/DataStoreObserveQueryTests.swift @@ -334,7 +334,7 @@ class DataStoreObserveQueryTests: SyncEngineIntegrationTestBase { /// - The final snapshot should have all the latest models with `isSynced` true /// func testObserveQuery_withClearedDataStore_fullySyncedWithMaxRecords() async throws { - await setUp(withModels: TestModelRegistration()) + await setUp(withModels: TestModelRegistration(), logLevel: .verbose) try await startAmplifyAndWaitForReady() try await clearDataStore() diff --git a/AmplifyPlugins/DataStore/Tests/DataStoreHostApp/DataStoreHostApp.xcodeproj/project.pbxproj b/AmplifyPlugins/DataStore/Tests/DataStoreHostApp/DataStoreHostApp.xcodeproj/project.pbxproj index 54af42878f..f6fd20640b 100644 --- a/AmplifyPlugins/DataStore/Tests/DataStoreHostApp/DataStoreHostApp.xcodeproj/project.pbxproj +++ b/AmplifyPlugins/DataStore/Tests/DataStoreHostApp/DataStoreHostApp.xcodeproj/project.pbxproj @@ -565,6 +565,7 @@ 602E8BF92A37D13700A3EA1E /* AWSDataStoreAWSIPAddressSortKeyTest.swift in Sources */ = {isa = PBXBuildFile; fileRef = 602E8BF32A37D13700A3EA1E /* AWSDataStoreAWSIPAddressSortKeyTest.swift */; }; 602E8BFA2A37D13700A3EA1E /* AWSDataStoreAWSEmailSortKeyTest.swift in Sources */ = {isa = PBXBuildFile; fileRef = 602E8BF52A37D13700A3EA1E /* AWSDataStoreAWSEmailSortKeyTest.swift */; }; 602E8BFC2A37D14900A3EA1E /* AWSDataStoreSortKeyBaseTest.swift in Sources */ = {isa = PBXBuildFile; fileRef = 602E8BFB2A37D14900A3EA1E /* AWSDataStoreSortKeyBaseTest.swift */; }; + 606C8B7B2B8FAFF700716094 /* DataStoreLargeNumberModelsSubscriptionTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 606C8B7A2B8FAFF700716094 /* DataStoreLargeNumberModelsSubscriptionTests.swift */; }; 6080BE5F2A37D48A0086EBDF /* Post16.swift in Sources */ = {isa = PBXBuildFile; fileRef = 6080BE5D2A37D48A0086EBDF /* Post16.swift */; }; 6080BE602A37D48A0086EBDF /* Post16+Schema.swift in Sources */ = {isa = PBXBuildFile; fileRef = 6080BE5E2A37D48A0086EBDF /* Post16+Schema.swift */; }; 6080BE632A37D4940086EBDF /* Post17+Schema.swift in Sources */ = {isa = PBXBuildFile; fileRef = 6080BE612A37D4940086EBDF /* Post17+Schema.swift */; }; @@ -2135,6 +2136,7 @@ 602E8BF32A37D13700A3EA1E /* AWSDataStoreAWSIPAddressSortKeyTest.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = AWSDataStoreAWSIPAddressSortKeyTest.swift; sourceTree = ""; }; 602E8BF52A37D13700A3EA1E /* AWSDataStoreAWSEmailSortKeyTest.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = AWSDataStoreAWSEmailSortKeyTest.swift; sourceTree = ""; }; 602E8BFB2A37D14900A3EA1E /* AWSDataStoreSortKeyBaseTest.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = AWSDataStoreSortKeyBaseTest.swift; sourceTree = ""; }; + 606C8B7A2B8FAFF700716094 /* DataStoreLargeNumberModelsSubscriptionTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = DataStoreLargeNumberModelsSubscriptionTests.swift; sourceTree = ""; }; 6080BE5D2A37D48A0086EBDF /* Post16.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Post16.swift; sourceTree = ""; }; 6080BE5E2A37D48A0086EBDF /* Post16+Schema.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = "Post16+Schema.swift"; sourceTree = ""; }; 6080BE612A37D4940086EBDF /* Post17+Schema.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = "Post17+Schema.swift"; sourceTree = ""; }; @@ -3088,6 +3090,7 @@ 21BBFC25289C054900B32A39 /* Models */, 21BBF9E7289BFE3400B32A39 /* TestSupport */, 21BBFA0D289BFE3400B32A39 /* README.md */, + 606C8B7A2B8FAFF700716094 /* DataStoreLargeNumberModelsSubscriptionTests.swift */, 21BBFA0C289BFE3400B32A39 /* SubscriptionEndToEndTests.swift */, 21BBFA0E289BFE3400B32A39 /* DataStoreScalarTests.swift */, 21BBFA15289BFE3400B32A39 /* DataStoreObserveQueryTests.swift */, @@ -5280,6 +5283,7 @@ 21BBFD62289C06E400B32A39 /* TodoWithDefaultValueV2.swift in Sources */, 21BBFD8E289C06E400B32A39 /* Nested.swift in Sources */, 21BBFDE2289C06E400B32A39 /* QPredGen+Schema.swift in Sources */, + 606C8B7B2B8FAFF700716094 /* DataStoreLargeNumberModelsSubscriptionTests.swift in Sources */, 21BBFDC2289C06E400B32A39 /* Blog6.swift in Sources */, 21BBFD4E289C06E400B32A39 /* Team4aV2.swift in Sources */, 21BBFDD8289C06E400B32A39 /* Team2+Schema.swift in Sources */, diff --git a/AmplifyPlugins/Logging/Sources/AWSCloudWatchLoggingPlugin/Resources/PrivacyInfo.xcprivacy b/AmplifyPlugins/Logging/Sources/AWSCloudWatchLoggingPlugin/Resources/PrivacyInfo.xcprivacy index 185f477f93..b3c2cb18e1 100644 --- a/AmplifyPlugins/Logging/Sources/AWSCloudWatchLoggingPlugin/Resources/PrivacyInfo.xcprivacy +++ b/AmplifyPlugins/Logging/Sources/AWSCloudWatchLoggingPlugin/Resources/PrivacyInfo.xcprivacy @@ -1,32 +1,40 @@ - - NSPrivacyCollectedDataTypes - - - NSPrivacyCollectedDataType - NSPrivacyCollectedDataTypeOtherDataTypes - NSPrivacyCollectedDataTypeLinked - - NSPrivacyCollectedDataTypeTracking - - NSPrivacyCollectedDataTypePurposes - - NSPrivacyCollectedDataTypePurposeAnalytics - - - - NSPrivacyAccessedAPITypes - - - NSPrivacyAccessedAPIType - NSPrivacyAccessedAPICategoryUserDefaults - NSPrivacyAccessedAPITypeReasons - - CA92.1 - - - - + + NSPrivacyCollectedDataTypes + + + NSPrivacyCollectedDataType + NSPrivacyCollectedDataTypeOtherDataTypes + NSPrivacyCollectedDataTypeLinked + + NSPrivacyCollectedDataTypeTracking + + NSPrivacyCollectedDataTypePurposes + + NSPrivacyCollectedDataTypePurposeAnalytics + + + + NSPrivacyAccessedAPITypes + + + NSPrivacyAccessedAPIType + NSPrivacyAccessedAPICategoryUserDefaults + NSPrivacyAccessedAPITypeReasons + + CA92.1 + + + + NSPrivacyAccessedAPIType + NSPrivacyAccessedAPICategoryFileTimestamp + NSPrivacyAccessedAPITypeReasons + + C617.1 + + + + diff --git a/AmplifyPlugins/Notifications/Push/Tests/PushNotificationHostApp/PushNotificationHostApp.xcodeproj/xcshareddata/xcschemes/PushNotificationHostApp.xcscheme b/AmplifyPlugins/Notifications/Push/Tests/PushNotificationHostApp/PushNotificationHostApp.xcodeproj/xcshareddata/xcschemes/PushNotificationHostApp.xcscheme new file mode 100644 index 0000000000..befc4a780c --- /dev/null +++ b/AmplifyPlugins/Notifications/Push/Tests/PushNotificationHostApp/PushNotificationHostApp.xcodeproj/xcshareddata/xcschemes/PushNotificationHostApp.xcscheme @@ -0,0 +1,77 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/AmplifyPlugins/Notifications/Push/Tests/PushNotificationHostApp/PushNotificationHostApp.xcodeproj/xcshareddata/xcschemes/PushNotificationHostAppUITests.xcscheme b/AmplifyPlugins/Notifications/Push/Tests/PushNotificationHostApp/PushNotificationHostApp.xcodeproj/xcshareddata/xcschemes/PushNotificationHostAppUITests.xcscheme new file mode 100644 index 0000000000..82b0044f1f --- /dev/null +++ b/AmplifyPlugins/Notifications/Push/Tests/PushNotificationHostApp/PushNotificationHostApp.xcodeproj/xcshareddata/xcschemes/PushNotificationHostAppUITests.xcscheme @@ -0,0 +1,54 @@ + + + + + + + + + + + + + + + + + + + + + diff --git a/AmplifyPlugins/Notifications/Push/Tests/PushNotificationHostApp/PushNotificationHostApp.xcodeproj/xcshareddata/xcschemes/PushNotificationsWatchApp.xcscheme b/AmplifyPlugins/Notifications/Push/Tests/PushNotificationHostApp/PushNotificationHostApp.xcodeproj/xcshareddata/xcschemes/PushNotificationsWatchApp.xcscheme new file mode 100644 index 0000000000..0f25f261e2 --- /dev/null +++ b/AmplifyPlugins/Notifications/Push/Tests/PushNotificationHostApp/PushNotificationHostApp.xcodeproj/xcshareddata/xcschemes/PushNotificationsWatchApp.xcscheme @@ -0,0 +1,77 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/AmplifyTests/CoreTests/JSONValueTests.swift b/AmplifyTests/CoreTests/JSONValueTests.swift index c719746893..e677d4e33b 100644 --- a/AmplifyTests/CoreTests/JSONValueTests.swift +++ b/AmplifyTests/CoreTests/JSONValueTests.swift @@ -90,4 +90,50 @@ class JSONValueTests: XCTestCase { let literalValue: JSONValue = ["foo": "bar"] XCTAssertEqual(enumValue, literalValue) } + + func testDynamicMemberLookup() { + let json = JSONValue.object(["foo": .object(["bar": 2])]) + XCTAssertEqual(json.foo?.bar?.intValue, 2) + } + + func testIntValue() { + let offset = 100000 + let badInt = JSONValue.number(Double(Int.max)) + XCTAssertNil(badInt.intValue) + let badInt2 = JSONValue.number(Double(Int.min) - Double(offset)) + XCTAssertNil(badInt2.intValue) + let goodInt = JSONValue.number(Double(100)) + XCTAssertEqual(goodInt.intValue, 100) + } + + func testDoubleValue() { + let double = 1000.0 + XCTAssertEqual(JSONValue.number(double).doubleValue, double) + } + + func testStringValue() { + let str = UUID().uuidString + XCTAssertEqual(JSONValue.string(str).stringValue, str) + } + + func testBooleanValue() { + let bool = false + XCTAssertEqual(JSONValue.boolean(bool).booleanValue, bool) + } + + func testObjectValue() { + let obj: [String: JSONValue] = [ + "a": "a", + "b": 0, + "c": false + ] + + XCTAssertEqual(JSONValue.object(obj).asObject, obj) + } + + func testArrayValue() { + let arr: [JSONValue] = ["a", 0, false] + XCTAssertEqual(JSONValue.array(arr).asArray, arr) + } + } diff --git a/CHANGELOG.md b/CHANGELOG.md index 111e7fec66..fb152a5b6c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,11 @@ # Changelog +## 2.27.3 (2024-03-18) + +### Bug Fixes + +- **Logging**: Updating the required reason API usage (#3570) + ## 2.27.2 (2024-03-11) ### Bug Fixes diff --git a/Package.resolved b/Package.resolved index 433497e13e..5be60bbe90 100644 --- a/Package.resolved +++ b/Package.resolved @@ -9,15 +9,6 @@ "version" : "1.1.1" } }, - { - "identity" : "aws-appsync-realtime-client-ios", - "kind" : "remoteSourceControl", - "location" : "https://github.com/aws-amplify/aws-appsync-realtime-client-ios.git", - "state" : { - "revision" : "a08684c5004e2049c29f57a5938beae9695a1ef7", - "version" : "3.1.2" - } - }, { "identity" : "aws-crt-swift", "kind" : "remoteSourceControl", @@ -72,15 +63,6 @@ "version" : "0.13.2" } }, - { - "identity" : "starscream", - "kind" : "remoteSourceControl", - "location" : "https://github.com/daltoniam/Starscream", - "state" : { - "revision" : "df8d82047f6654d8e4b655d1b1525c64e1059d21", - "version" : "4.0.4" - } - }, { "identity" : "swift-log", "kind" : "remoteSourceControl", diff --git a/Package.swift b/Package.swift index 43065db332..77f5211f2a 100644 --- a/Package.swift +++ b/Package.swift @@ -10,7 +10,6 @@ let platforms: [SupportedPlatform] = [ ] let dependencies: [Package.Dependency] = [ .package(url: "https://github.com/awslabs/aws-sdk-swift.git", exact: "0.36.1"), - .package(url: "https://github.com/aws-amplify/aws-appsync-realtime-client-ios.git", from: "3.0.0"), .package(url: "https://github.com/stephencelis/SQLite.swift.git", exact: "0.13.2"), .package(url: "https://github.com/mattgallagher/CwlPreconditionTesting.git", from: "2.1.0"), .package(url: "https://github.com/aws-amplify/amplify-swift-utils-notifications.git", from: "1.1.0") @@ -116,8 +115,8 @@ let apiTargets: [Target] = [ name: "AWSAPIPlugin", dependencies: [ .target(name: "Amplify"), - .target(name: "AWSPluginsCore"), - .product(name: "AppSyncRealTimeClient", package: "aws-appsync-realtime-client-ios")], + .target(name: "AWSPluginsCore") + ], path: "AmplifyPlugins/API/Sources/AWSAPIPlugin", exclude: [ "Info.plist",