diff --git a/AmplifyPlugins/Auth/Sources/AWSCognitoAuthPlugin/CredentialStorage/AWSCognitoAuthCredentialStore.swift b/AmplifyPlugins/Auth/Sources/AWSCognitoAuthPlugin/CredentialStorage/AWSCognitoAuthCredentialStore.swift index 1e8b1f77f5..dd55162b57 100644 --- a/AmplifyPlugins/Auth/Sources/AWSCognitoAuthPlugin/CredentialStorage/AWSCognitoAuthCredentialStore.swift +++ b/AmplifyPlugins/Auth/Sources/AWSCognitoAuthPlugin/CredentialStorage/AWSCognitoAuthCredentialStore.swift @@ -133,8 +133,8 @@ extension AWSCognitoAuthCredentialStore: AmplifyAuthCredentialStoreBehavior { func retrieveCredential() throws -> AmplifyCredentials { let authCredentialStoreKey = generateSessionKey(for: authConfiguration) let authCredentialData = try keychain._getData(authCredentialStoreKey) - let awsCredential: AmplifyCredentials = try decode(data: authCredentialData) - return awsCredential + let amplifyCredential: AmplifyCredentials = try decode(data: authCredentialData) + return amplifyCredential } func deleteCredential() throws { @@ -191,7 +191,7 @@ private extension AWSCognitoAuthCredentialStore { do { return try JSONEncoder().encode(object) } catch { - throw KeychainStoreError.codingError("Error occurred while encoding AWSCredentials", error) + throw KeychainStoreError.codingError("Error occurred while encoding credentials", error) } } @@ -199,7 +199,7 @@ private extension AWSCognitoAuthCredentialStore { do { return try JSONDecoder().decode(T.self, from: data) } catch { - throw KeychainStoreError.codingError("Error occurred while decoding AWSCredentials", error) + throw KeychainStoreError.codingError("Error occurred while decoding credentials", error) } } diff --git a/AmplifyPlugins/Auth/Sources/AWSCognitoAuthPlugin/Models/AuthFlowType.swift b/AmplifyPlugins/Auth/Sources/AWSCognitoAuthPlugin/Models/AuthFlowType.swift index 46b701bcd3..b352e873c5 100644 --- a/AmplifyPlugins/Auth/Sources/AWSCognitoAuthPlugin/Models/AuthFlowType.swift +++ b/AmplifyPlugins/Auth/Sources/AWSCognitoAuthPlugin/Models/AuthFlowType.swift @@ -36,8 +36,10 @@ public enum AuthFlowType { internal init?(rawValue: String) { switch rawValue { - case "CUSTOM_AUTH": + case "CUSTOM_AUTH", "CUSTOM_AUTH_WITH_SRP": self = .customWithSRP + case "CUSTOM_AUTH_WITHOUT_SRP": + self = .customWithoutSRP case "USER_SRP_AUTH": self = .userSRP case "USER_PASSWORD_AUTH": @@ -51,8 +53,10 @@ public enum AuthFlowType { var rawValue: String { switch self { - case .custom, .customWithSRP, .customWithoutSRP: - return "CUSTOM_AUTH" + case .custom, .customWithSRP: + return "CUSTOM_AUTH_WITH_SRP" + case .customWithoutSRP: + return "CUSTOM_AUTH_WITHOUT_SRP" case .userSRP: return "USER_SRP_AUTH" case .userPassword: @@ -62,6 +66,24 @@ public enum AuthFlowType { } } + // This initializer has been added to migrate credentials that were created in the pre-passwordless era + internal static func legacyInit(rawValue: String) -> Self? { + switch rawValue { + case "userSRP": + return .userSRP + case "userPassword": + return .userPassword + case "custom": + return .custom + case "customWithSRP": + return .customWithSRP + case "customWithoutSRP": + return .customWithoutSRP + default: + return nil + } + } + public static var userAuth: AuthFlowType { return .userAuth(preferredFirstFactor: nil) } @@ -110,27 +132,49 @@ extension AuthFlowType: Codable { // Decoding the enum public init(from decoder: Decoder) throws { - let container = try decoder.container(keyedBy: CodingKeys.self) + let container: KeyedDecodingContainer + do { + container = try decoder.container(keyedBy: CodingKeys.self) + } catch DecodingError.typeMismatch { + // The type mismatch has been added to handle a scenario where the user is migrating passwordless flows. + // Passwordless flow added a new enum case with a associated type. + // The association resulted in encoding structure changes that is different from the non-passwordless flows. + // The structure change causes the type mismatch exception and this code block tries to retrieve the legacy structure and decode it. + let legacyContainer = try decoder.singleValueContainer() + let type = try legacyContainer.decode(String.self) + guard let authFlowType = AuthFlowType.legacyInit(rawValue: type) else { + throw DecodingError.dataCorruptedError(in: legacyContainer, debugDescription: "Invalid AuthFlowType value") + } + self = authFlowType + return + } catch { + throw error + } - // Decode the type (raw value) let type = try container.decode(String.self, forKey: .type) // Initialize based on the type switch type { case "USER_SRP_AUTH": self = .userSRP - case "CUSTOM_AUTH": - // Depending on your needs, choose either `.custom`, `.customWithSRP`, or `.customWithoutSRP` - // In this case, we'll default to `.custom` - self = .custom + case "CUSTOM_AUTH", "CUSTOM_AUTH_WITH_SRP": + self = .customWithSRP + case "CUSTOM_AUTH_WITHOUT_SRP": + self = .customWithoutSRP case "USER_PASSWORD_AUTH": self = .userPassword case "USER_AUTH": - let preferredFirstFactorString = try container.decode(String.self, forKey: .preferredFirstFactor) - if let preferredFirstFactor = AuthFactorType(rawValue: preferredFirstFactorString) { - self = .userAuth(preferredFirstFactor: preferredFirstFactor) + if let preferredFirstFactorString = try container.decodeIfPresent(String.self, forKey: .preferredFirstFactor) { + if let preferredFirstFactor = AuthFactorType(rawValue: preferredFirstFactorString) { + self = .userAuth(preferredFirstFactor: preferredFirstFactor) + } else { + throw DecodingError.dataCorruptedError( + forKey: .preferredFirstFactor, + in: container, + debugDescription: "Unable to decode preferredFirstFactor value") + } } else { - throw DecodingError.dataCorruptedError(forKey: .type, in: container, debugDescription: "Unable to decode preferredFirstFactor value") + self = .userAuth(preferredFirstFactor: nil) } default: throw DecodingError.dataCorruptedError(forKey: .type, in: container, debugDescription: "Invalid AuthFlowType value") @@ -152,5 +196,4 @@ extension AuthFlowType { return .userAuth } } - } diff --git a/AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/ConfigurationTests/AuthFlowTypeTests.swift b/AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/ConfigurationTests/AuthFlowTypeTests.swift new file mode 100644 index 0000000000..c56c02f0cf --- /dev/null +++ b/AmplifyPlugins/Auth/Tests/AWSCognitoAuthPluginUnitTests/ConfigurationTests/AuthFlowTypeTests.swift @@ -0,0 +1,144 @@ +// +// Copyright Amazon.com Inc. or its affiliates. +// All Rights Reserved. +// +// SPDX-License-Identifier: Apache-2.0 +// + + +import XCTest +@testable import AWSCognitoAuthPlugin + +class AuthFlowTypeTests: XCTestCase { + + func testRawValue() { + XCTAssertEqual(AuthFlowType.userSRP.rawValue, "USER_SRP_AUTH") + XCTAssertEqual(AuthFlowType.customWithSRP.rawValue, "CUSTOM_AUTH_WITH_SRP") + XCTAssertEqual(AuthFlowType.customWithoutSRP.rawValue, "CUSTOM_AUTH_WITHOUT_SRP") + XCTAssertEqual(AuthFlowType.userPassword.rawValue, "USER_PASSWORD_AUTH") + XCTAssertEqual(AuthFlowType.userAuth(preferredFirstFactor: nil).rawValue, "USER_AUTH") + } + + func testInitWithRawValue() { + XCTAssertEqual(AuthFlowType(rawValue: "USER_SRP_AUTH"), .userSRP) + XCTAssertEqual(AuthFlowType(rawValue: "CUSTOM_AUTH"), .customWithSRP) + XCTAssertEqual(AuthFlowType(rawValue: "CUSTOM_AUTH_WITH_SRP"), .customWithSRP) + XCTAssertEqual(AuthFlowType(rawValue: "CUSTOM_AUTH_WITHOUT_SRP"), .customWithoutSRP) + XCTAssertEqual(AuthFlowType(rawValue: "USER_PASSWORD_AUTH"), .userPassword) + XCTAssertEqual(AuthFlowType(rawValue: "USER_AUTH"), .userAuth(preferredFirstFactor: nil)) + XCTAssertNil(AuthFlowType(rawValue: "INVALID_AUTH")) + } + + func testDeprecatedCustom() { + // This test is to ensure the deprecated case is still functional + XCTAssertEqual(AuthFlowType.custom.rawValue, "CUSTOM_AUTH_WITH_SRP") + } + + func testEncoding() throws { + let encoder = JSONEncoder() + let userSRP = try encoder.encode(AuthFlowType.userSRP) + XCTAssertEqual(String(data: userSRP, encoding: .utf8), "{\"type\":\"USER_SRP_AUTH\"}") + + let customWithSRP = try encoder.encode(AuthFlowType.customWithSRP) + XCTAssertEqual(String(data: customWithSRP, encoding: .utf8), "{\"type\":\"CUSTOM_AUTH_WITH_SRP\"}") + + let customWithoutSRP = try encoder.encode(AuthFlowType.customWithoutSRP) + XCTAssertEqual(String(data: customWithoutSRP, encoding: .utf8), "{\"type\":\"CUSTOM_AUTH_WITHOUT_SRP\"}") + + let userPassword = try encoder.encode(AuthFlowType.userPassword) + XCTAssertEqual(String(data: userPassword, encoding: .utf8), "{\"type\":\"USER_PASSWORD_AUTH\"}") + + let userAuth = try encoder.encode(AuthFlowType.userAuth(preferredFirstFactor: nil)) + XCTAssertTrue(String(data: userAuth, encoding: .utf8)?.contains("\"preferredFirstFactor\":null") == true) + XCTAssertTrue(String(data: userAuth, encoding: .utf8)?.contains("\"type\":\"USER_AUTH\"") == true) + } + + func testDecoding() throws { + let decoder = JSONDecoder() + let userSRP = try decoder.decode(AuthFlowType.self, from: "{\"type\":\"USER_SRP_AUTH\"}".data(using: .utf8)!) + XCTAssertEqual(userSRP, .userSRP) + + let customWithSRP = try decoder.decode(AuthFlowType.self, from: "{\"type\":\"CUSTOM_AUTH_WITH_SRP\"}".data(using: .utf8)!) + XCTAssertEqual(customWithSRP, .customWithSRP) + + let customWithoutSRP = try decoder.decode(AuthFlowType.self, from: "{\"type\":\"CUSTOM_AUTH_WITHOUT_SRP\"}".data(using: .utf8)!) + XCTAssertEqual(customWithoutSRP, .customWithoutSRP) + + let userPassword = try decoder.decode(AuthFlowType.self, from: "{\"type\":\"USER_PASSWORD_AUTH\"}".data(using: .utf8)!) + XCTAssertEqual(userPassword, .userPassword) + + let userAuth = try decoder.decode(AuthFlowType.self, from: "{\"type\":\"USER_AUTH\"}".data(using: .utf8)!) + XCTAssertEqual(userAuth, .userAuth(preferredFirstFactor: nil)) + } + + func testDecodingWithPreferredFirstFactor() throws { + let decoder = JSONDecoder() + let json = """ + { + "type": "USER_AUTH", + "preferredFirstFactor": "SMS_OTP" + } + """.data(using: .utf8)! + let authFlowType = try decoder.decode(AuthFlowType.self, from: json) + XCTAssertEqual(authFlowType, .userAuth(preferredFirstFactor: .smsOTP)) + } + + func testDecodingLegacyStructure() throws { + let decoder = JSONDecoder() + var legacyJson = "\"userSRP\"".data(using: .utf8)! + var authFlowType = try decoder.decode(AuthFlowType.self, from: legacyJson) + XCTAssertEqual(authFlowType, .userSRP) + + legacyJson = "\"userPassword\"".data(using: .utf8)! + authFlowType = try decoder.decode(AuthFlowType.self, from: legacyJson) + XCTAssertEqual(authFlowType, .userPassword) + + legacyJson = "\"customWithSRP\"".data(using: .utf8)! + authFlowType = try decoder.decode(AuthFlowType.self, from: legacyJson) + XCTAssertEqual(authFlowType, .customWithSRP) + + legacyJson = "\"customWithoutSRP\"".data(using: .utf8)! + authFlowType = try decoder.decode(AuthFlowType.self, from: legacyJson) + XCTAssertEqual(authFlowType, .customWithoutSRP) + + legacyJson = "\"custom\"".data(using: .utf8)! + authFlowType = try decoder.decode(AuthFlowType.self, from: legacyJson) + XCTAssertEqual(authFlowType, .custom) + } + + func testDecodingInvalidType() { + let decoder = JSONDecoder() + let invalidJson = "{\"type\":\"INVALID_AUTH\"}".data(using: .utf8)! + XCTAssertThrowsError(try decoder.decode(AuthFlowType.self, from: invalidJson)) { error in + guard case DecodingError.dataCorrupted(let context) = error else { + return XCTFail("Expected dataCorrupted error") + } + XCTAssertEqual(context.debugDescription, "Invalid AuthFlowType value") + } + } + + func testDecodingInvalidPreferredFirstFactor() { + let decoder = JSONDecoder() + let invalidJson = """ + { + "type": "USER_AUTH", + "preferredFirstFactor": "INVALID_FACTOR" + } + """.data(using: .utf8)! + XCTAssertThrowsError(try decoder.decode(AuthFlowType.self, from: invalidJson)) { error in + guard case DecodingError.dataCorrupted(let context) = error else { + return XCTFail("Expected dataCorrupted error") + } + XCTAssertEqual(context.debugDescription, "Unable to decode preferredFirstFactor value") + } + } + + func testGetClientFlowType() { + XCTAssertEqual(AuthFlowType.custom.getClientFlowType(), .customAuth) + XCTAssertEqual(AuthFlowType.customWithSRP.getClientFlowType(), .customAuth) + XCTAssertEqual(AuthFlowType.customWithoutSRP.getClientFlowType(), .customAuth) + XCTAssertEqual(AuthFlowType.userSRP.getClientFlowType(), .userSrpAuth) + XCTAssertEqual(AuthFlowType.userPassword.getClientFlowType(), .userPasswordAuth) + XCTAssertEqual(AuthFlowType.userAuth(preferredFirstFactor: nil).getClientFlowType(), .userAuth) + } +}