From 11d94ef3fc81babfa650d6d4a5d35152c7d61c4f Mon Sep 17 00:00:00 2001 From: Karen <64801825+karenc-bq@users.noreply.github.com> Date: Wed, 18 Sep 2024 11:50:30 -0700 Subject: [PATCH] feat: adfs authentication support (#214) --- .github/workflows/build-installer.yml | 7 + driver/adfs_proxy.cc | 225 ++++++++++++++++++++++++-- driver/adfs_proxy.h | 69 +++++--- driver/auth_util.cc | 50 +++++- driver/auth_util.h | 8 +- driver/iam_proxy.cc | 61 +------ driver/iam_proxy.h | 7 - driver/okta_proxy.cc | 30 ++-- driver/okta_proxy.h | 2 +- driver/saml_http_client.cc | 61 +++++-- driver/saml_http_client.h | 4 +- unit_testing/CMakeLists.txt | 1 + unit_testing/adfs_proxy_test.cc | 118 ++++++++++++++ unit_testing/iam_proxy_test.cc | 109 ++++++------- unit_testing/mock_objects.h | 9 +- unit_testing/okta_proxy_test.cc | 7 +- unit_testing/test_utils.cc | 8 +- unit_testing/test_utils.h | 3 +- 18 files changed, 586 insertions(+), 193 deletions(-) create mode 100644 unit_testing/adfs_proxy_test.cc diff --git a/.github/workflows/build-installer.yml b/.github/workflows/build-installer.yml index 1b7c95ac3..10ac9aa93 100644 --- a/.github/workflows/build-installer.yml +++ b/.github/workflows/build-installer.yml @@ -22,6 +22,13 @@ jobs: run: | curl -L https://dev.mysql.com/get/Downloads/MySQL-8.3/mysql-${{ vars.MYSQL_VERSION }}-winx64.zip -o mysql.zip unzip -d C:/ mysql.zip + + - name: Install OpenSSL 3 + run: | + curl -L https://download.firedaemon.com/FireDaemon-OpenSSL/openssl-3.3.1.zip -o openssl3.zip + unzip -d C:/ openssl3.zip + cp -r C:/openssl-3/x64/bin/libssl-3-x64.dll C:/Windows/System32/ + cp -r C:/openssl-3/x64/bin/libcrypto-3-x64.dll C:/Windows/System32/ - name: Add msbuild to PATH uses: microsoft/setup-msbuild@v2 diff --git a/driver/adfs_proxy.cc b/driver/adfs_proxy.cc index 608caed66..2a94bf04c 100644 --- a/driver/adfs_proxy.cc +++ b/driver/adfs_proxy.cc @@ -28,25 +28,186 @@ // http://www.gnu.org/licenses/gpl-2.0.html. #include "adfs_proxy.h" +#include #include "driver.h" +#define SIGN_IN_PAGE_URL "/adfs/ls/IdpInitiatedSignOn.aspx?loginToRp=urn:amazon:webservices" + +std::unordered_map ADFS_PROXY::token_cache; +std::mutex ADFS_PROXY::token_cache_mutex; + ADFS_PROXY::ADFS_PROXY(DBC* dbc, DataSource* ds) : ADFS_PROXY(dbc, ds, nullptr) {}; ADFS_PROXY::ADFS_PROXY(DBC* dbc, DataSource* ds, CONNECTION_PROXY* next_proxy) : CONNECTION_PROXY(dbc, ds) { this->next_proxy = next_proxy; - if (ds->opt_AUTH_REGION) { - this->auth_util = std::make_shared((const char*)ds->opt_AUTH_REGION); + std::string host{static_cast(ds->opt_IDP_ENDPOINT)}; + host += ":" + std::to_string(ds->opt_IDP_PORT); + + const int client_connect_timeout = ds->opt_CLIENT_CONNECT_TIMEOUT; + const int client_socket_timeout = ds->opt_CLIENT_SOCKET_TIMEOUT; + const bool enable_ssl = ds->opt_ENABLE_SSL; + this->saml_util = std::make_shared(host, client_connect_timeout, client_socket_timeout, enable_ssl); +} + +void ADFS_PROXY::clear_token_cache() { + std::unique_lock lock(token_cache_mutex); + token_cache.clear(); +} + +ADFS_SAML_UTIL::ADFS_SAML_UTIL(const std::shared_ptr& client) { this->http_client = client; } + +ADFS_SAML_UTIL::ADFS_SAML_UTIL(std::string host, int connect_timeout, int socket_timeout, bool enable_ssl) { + this->http_client = + std::make_shared("https://" + host, connect_timeout, socket_timeout, enable_ssl); +} + +std::string ADFS_SAML_UTIL::get_saml_assertion(DataSource* ds) { + nlohmann::json res; + try { + res = this->http_client->get(std::string(SIGN_IN_PAGE_URL)); + } catch (SAML_HTTP_EXCEPTION& e) { + const std::string error = + "Failed to get sign-in page from ADFS: " + e.error_message() + ". Please verify your IDP endpoint."; + throw SAML_HTTP_EXCEPTION(error); + } + + const auto body = std::string(res); + std::smatch m; + if (!std::regex_search(body, m, ADFS_REGEX::FORM_ACTION_PATTERN)) { + return std::string(); } - else { - this->auth_util = std::make_shared(); + std::string form_action = unescape_html_entity(m.str(1)); + const std::string params = get_parameters_from_html(ds, body); + const std::string content = get_form_action_body(form_action, params); + if (std::regex_search(content, m, ADFS_REGEX::SAML_RESPONSE_PATTERN)) { + return m.str(1); } + return std::string(); +} + +std::string ADFS_SAML_UTIL::unescape_html_entity(const std::string& html) { + std::string retval(""); + int i = 0; + int length = html.length(); + while (i < length) { + char c = html[i]; + if (c != '&') { + retval.append(1, c); + i++; + continue; + } + + if (html.substr(i, 4) == "<") { + retval.append(1, '<'); + i += 4; + } else if (html.substr(i, 4) == ">") { + retval.append(1, '>'); + i += 4; + } else if (html.substr(i, 5) == "&") { + retval.append(1, '&'); + i += 5; + } else if (html.substr(i, 6) == "'") { + retval.append(1, '\''); + i += 6; + } else if (html.substr(i, 6) == """) { + retval.append(1, '"'); + i += 6; + } else { + retval.append(1, c); + ++i; + } + } + return retval; +} + +std::vector ADFS_SAML_UTIL::get_input_tags_from_html(const std::string& body) { + std::unordered_set hashSet; + std::vector retval; + + std::smatch matches; + std::regex pattern(ADFS_REGEX::INPUT_TAG_PATTERN); + std::string source = body; + while (std::regex_search(source, matches, pattern)) { + std::string tag = matches.str(0); + std::string tagName = get_value_by_key(tag, std::string("name")); + std::transform(tagName.begin(), tagName.end(), tagName.begin(), [](unsigned char c) { return std::tolower(c); }); + if (!tagName.empty() && hashSet.find(tagName) == hashSet.end()) { + hashSet.insert(tagName); + retval.push_back(tag); + } + + source = matches.suffix().str(); + } + + return retval; +} + +std::string ADFS_SAML_UTIL::get_value_by_key(const std::string& input, const std::string& key) { + std::string pattern("("); + pattern += key; + pattern += ")\\s*=\\s*\"(.*?)\""; + + std::smatch matches; + if (std::regex_search(input, matches, std::regex(pattern))) { + MYLOG_TRACE(init_log_file(), 0, "get_value_by_key"); + return unescape_html_entity(matches.str(2)); + } + return ""; +} + +std::string ADFS_SAML_UTIL::get_parameters_from_html(DataSource* ds, const std::string& body) { + std::map parameters; + for (auto& inputTag : get_input_tags_from_html(body)) { + std::string name = get_value_by_key(inputTag, std::string("name")); + std::string value = get_value_by_key(inputTag, std::string("value")); + std::string nameLower = name; + std::transform(nameLower.begin(), nameLower.end(), nameLower.begin(), + [](unsigned char c) { return std::tolower(c); }); + + const std::string username = static_cast(ds->opt_IDP_USERNAME); + const std::string password = static_cast(ds->opt_IDP_PASSWORD); + + if (nameLower.find("username") != std::string::npos) { + parameters.insert(std::pair(name, username)); + } else if ((nameLower.find("authmethod") != std::string::npos) && !value.empty()) { + parameters.insert(std::pair(name, value)); + } else if (nameLower.find("password") != std::string::npos) { + parameters.insert(std::pair(name, password)); + } else if (!name.empty()) { + parameters.insert(std::pair(name, value)); + } + } + + // Convert parameters to a & delimited string, e.g. username=u&password=p + const std::string delimiter = "&"; + const std::string result = + std::accumulate(parameters.begin(), parameters.end(), std::string(), + [delimiter](const std::string& s, const std::pair& p) { + return s + (s.empty() ? std::string() : delimiter) + p.first + "=" + p.second; + }); + + return result; +} + +std::string ADFS_SAML_UTIL::get_form_action_body(const std::string& url, const std::string& params) { + nlohmann::json res; + try { + res = this->http_client->post(url, params, "application/x-www-form-urlencoded"); + } catch (SAML_HTTP_EXCEPTION& e) { + const std::string error = + "Failed to get SAML Assertion from ADFS : " + e.error_message() + ". Please verify your ADFS credentials."; + throw SAML_HTTP_EXCEPTION(error); + } + return res.empty() ? "" : res; } #ifdef UNIT_TEST_BUILD -ADFS_PROXY::ADFS_PROXY(DBC* dbc, DataSource* ds, CONNECTION_PROXY* next_proxy, - std::shared_ptr auth_util) : CONNECTION_PROXY(dbc, ds) { - this->next_proxy = next_proxy; - this->auth_util = auth_util; +ADFS_PROXY::ADFS_PROXY(DBC* dbc, DataSource* ds, CONNECTION_PROXY* next_proxy, std::shared_ptr auth_util, + const std::shared_ptr& client) + : CONNECTION_PROXY(dbc, ds) { + this->next_proxy = next_proxy; + this->auth_util = auth_util; + this->saml_util = std::make_shared(client); } #endif @@ -54,5 +215,51 @@ ADFS_PROXY::~ADFS_PROXY() { this->auth_util.reset(); } bool ADFS_PROXY::connect(const char* host, const char* user, const char* password, const char* database, unsigned int port, const char* socket, unsigned long flags) { - return true; + auto func = std::bind(&CONNECTION_PROXY::connect, next_proxy, host, user, std::placeholders::_1, database, port, + socket, flags); + const char* region = + ds->opt_FED_AUTH_REGION ? static_cast(ds->opt_FED_AUTH_REGION) : Aws::Region::US_EAST_1; + std::string assertion; + try { + assertion = this->saml_util->get_saml_assertion(ds); + } catch (SAML_HTTP_EXCEPTION& e) { + this->set_custom_error_message(e.error_message().c_str()); + return false; + } + + auto idp_host = static_cast(ds->opt_IDP_ENDPOINT); + auto iam_role_arn = static_cast(ds->opt_IAM_ROLE_ARN); + auto idp_arn = static_cast(ds->opt_IAM_IDP_ARN); + const Aws::Auth::AWSCredentials credentials = + this->saml_util->get_aws_credentials(idp_host, region, iam_role_arn, idp_arn, assertion); + this->auth_util = std::make_shared(region, credentials); + + const char* auth_host = ds->opt_FED_AUTH_HOST ? static_cast(ds->opt_FED_AUTH_HOST) + : static_cast(ds->opt_SERVER); + const int auth_port = ds->opt_FED_AUTH_PORT; + + std::string auth_token; + bool using_cached_token; + std::tie(auth_token, using_cached_token) = this->auth_util->get_auth_token( + token_cache, token_cache_mutex, auth_host, region, auth_port, ds->opt_UID, ds->opt_AUTH_EXPIRATION); + + bool connect_result = func(auth_token.c_str()); + if (!connect_result) { + if (using_cached_token) { + // Retry func with a fresh token + std::tie(auth_token, using_cached_token) = this->auth_util->get_auth_token( + token_cache, token_cache_mutex, auth_host, region, auth_port, ds->opt_UID, ds->opt_AUTH_EXPIRATION, true); + if (func(auth_token.c_str())) { + return true; + } + } + + if (credentials.IsEmpty()) { + this->set_custom_error_message( + "Unable to generate temporary AWS credentials from the SAML assertion. Please ensure the ADFS identity " + "provider is correctly configured with AWS."); + } + } + + return connect_result; } diff --git a/driver/adfs_proxy.h b/driver/adfs_proxy.h index fb1ad4e65..e3bd96b11 100644 --- a/driver/adfs_proxy.h +++ b/driver/adfs_proxy.h @@ -30,40 +30,61 @@ #ifndef __ADFS_PROXY__ #define __ADFS_PROXY__ +#include #include #include "auth_util.h" +#include "saml_http_client.h" +#include "saml_util.h" + +namespace ADFS_REGEX { + const std::regex FORM_ACTION_PATTERN(R"#()", std::regex_constants::icase); + const std::regex URL_PATTERN(R"#(^(https)://[-a-zA-Z0-9+&@#/%?=~_!:,.']*[-a-zA-Z0-9+&@#/%=~_'])#", + std::regex_constants::icase); + const std::regex INPUT_TAG_PATTERN(R"#(& client); + ADFS_SAML_UTIL(std::string host, int connect_timeout, int socket_timeout, bool enable_ssl); + std::string get_saml_assertion(DataSource* ds) override; + std::shared_ptr http_client; + + private: + static std::string unescape_html_entity(const std::string& html); + std::vector get_input_tags_from_html(const std::string& body); + std::string get_value_by_key(const std::string& input, const std::string& key); + std::string get_parameters_from_html(DataSource* ds, const std::string& body); + std::string get_form_action_body(const std::string& url, const std::string& params); +}; class ADFS_PROXY : public CONNECTION_PROXY { -public: - ADFS_PROXY() = default; - ADFS_PROXY(DBC* dbc, DataSource* ds); - ADFS_PROXY(DBC* dbc, DataSource* ds, CONNECTION_PROXY* next_proxy); + public: + ADFS_PROXY() = default; + ADFS_PROXY(DBC* dbc, DataSource* ds); + ADFS_PROXY(DBC* dbc, DataSource* ds, CONNECTION_PROXY* next_proxy); #ifdef UNIT_TEST_BUILD - ADFS_PROXY(DBC* dbc, DataSource* ds, CONNECTION_PROXY* next_proxy, std::shared_ptr auth_util); + ADFS_PROXY(DBC* dbc, DataSource* ds, CONNECTION_PROXY* next_proxy, std::shared_ptr auth_util, + const std::shared_ptr& client); #endif - ~ADFS_PROXY() override; - bool connect( - const char* host, - const char* user, - const char* password, - const char* database, - unsigned int port, - const char* socket, - unsigned long flags) override; - -protected: - static std::unordered_map token_cache; - static std::mutex token_cache_mutex; - std::shared_ptr auth_util; - bool using_cached_token = false; + ~ADFS_PROXY() override; + bool connect(const char* host, const char* user, const char* password, const char* database, unsigned int port, + const char* socket, unsigned long flags) override; + + protected: + static std::unordered_map token_cache; + static std::mutex token_cache_mutex; + std::shared_ptr auth_util; + std::shared_ptr saml_util; + bool using_cached_token = false; - static void clear_token_cache(); + static void clear_token_cache(); #ifdef UNIT_TEST_BUILD - // Allows for testing private/protected methods - friend class TEST_UTILS; + // Allows for testing private/protected methods + friend class TEST_UTILS; #endif }; #endif - diff --git a/driver/auth_util.cc b/driver/auth_util.cc index 4fa530e62..7ff6aac47 100644 --- a/driver/auth_util.cc +++ b/driver/auth_util.cc @@ -44,8 +44,7 @@ AUTH_UTIL::AUTH_UTIL(const char* region) { } this->rds_client = std::make_shared( - Aws::Auth::DefaultAWSCredentialsProviderChain().GetAWSCredentials(), - client_config); + Aws::Auth::DefaultAWSCredentialsProviderChain().GetAWSCredentials(), client_config); }; AUTH_UTIL::AUTH_UTIL(const char* region, Aws::Auth::AWSCredentials credentials) { @@ -59,7 +58,52 @@ AUTH_UTIL::AUTH_UTIL(const char* region, Aws::Auth::AWSCredentials credentials) this->rds_client = std::make_shared(credentials, client_config); } -std::string AUTH_UTIL::get_auth_token(const char* host, const char* region, unsigned int port, const char* user) { +std::pair AUTH_UTIL::get_auth_token(std::unordered_map& token_cache, + std::mutex& token_cache_mutex, const char* host, + const char* region, unsigned int port, const char* user, + unsigned int time_until_expiration, + bool force_generate_new_token) { + if (!host) { + host = ""; + } + if (!region) { + region = ""; + } + if (!user) { + user = ""; + } + + std::string auth_token; + const std::string cache_key = this->build_cache_key(host, region, port, user); + bool using_cached_token = false; + + { + std::unique_lock lock(token_cache_mutex); + + if (force_generate_new_token) { + token_cache.erase(cache_key); + } else { + // Search for token in cache + auto find_token = token_cache.find(cache_key); + if (find_token != token_cache.end()) { + TOKEN_INFO info = find_token->second; + if (info.is_expired()) { + token_cache.erase(cache_key); + } else { + using_cached_token = true; + return std::make_pair(info.token, using_cached_token); + } + } + } + + // Generate new token + auth_token = this->generate_token(host, region, port, user); + token_cache[cache_key] = TOKEN_INFO(auth_token, time_until_expiration); + } + return std::make_pair(auth_token, using_cached_token); +} + +std::string AUTH_UTIL::generate_token(const char* host, const char* region, unsigned int port, const char* user) { return this->rds_client->GenerateConnectAuthToken(host, region, port, user); } diff --git a/driver/auth_util.h b/driver/auth_util.h index f5e0d26cd..147fdd89b 100644 --- a/driver/auth_util.h +++ b/driver/auth_util.h @@ -63,12 +63,16 @@ class AUTH_UTIL { AUTH_UTIL(const char* region); AUTH_UTIL(const char* region, Aws::Auth::AWSCredentials credentials); ~AUTH_UTIL(); - - virtual std::string get_auth_token(const char* host, const char* region, unsigned int port, const char* user); + virtual std::pair get_auth_token(std::unordered_map& token_cache, + std::mutex& token_cache_mutex, const char* host, + const char* region, unsigned int port, const char* user, + unsigned int time_until_expiration, + bool force_generate_new_token = false); static std::string build_cache_key(const char* host, const char* region, unsigned int port, const char* user); private: std::shared_ptr rds_client; + virtual std::string generate_token(const char* host, const char* region, unsigned int port, const char* user); #ifdef UNIT_TEST_BUILD // Allows for testing private/protected methods diff --git a/driver/iam_proxy.cc b/driver/iam_proxy.cc index 47c5da0f5..93059a541 100644 --- a/driver/iam_proxy.cc +++ b/driver/iam_proxy.cc @@ -28,6 +28,7 @@ // http://www.gnu.org/licenses/gpl-2.0.html. #include +#include #include "driver.h" #include "iam_proxy.h" @@ -69,54 +70,6 @@ bool IAM_PROXY::change_user(const char* user, const char* passwd, const char* db return invoke_func_with_generated_token(f); } -std::string IAM_PROXY::get_auth_token( - const char* host, const char* region, unsigned int port, - const char* user, unsigned int time_until_expiration, - bool force_generate_new_token) { - - if (!host) { - host = ""; - } - if (!region) { - region = ""; - } - if (!user) { - user = ""; - } - - std::string auth_token; - std::string cache_key = this->auth_util->build_cache_key(host, region, port, user); - using_cached_token = false; - - { - std::unique_lock lock(token_cache_mutex); - - if (force_generate_new_token) { - token_cache.erase(cache_key); - } - else { - // Search for token in cache - auto find_token = token_cache.find(cache_key); - if (find_token != token_cache.end()) { - TOKEN_INFO info = find_token->second; - if (info.is_expired()) { - token_cache.erase(cache_key); - } else { - using_cached_token = true; - return info.token; - } - } - } - - // Generate new token - auth_token = this->auth_util->get_auth_token(host, region, port, user); - - token_cache[cache_key] = TOKEN_INFO(auth_token, time_until_expiration); - } - - return auth_token; -} - void IAM_PROXY::clear_token_cache() { std::unique_lock lock(token_cache_mutex); token_cache.clear(); @@ -125,7 +78,7 @@ void IAM_PROXY::clear_token_cache() { bool IAM_PROXY::invoke_func_with_generated_token(std::function func) { // Use user provided auth host if present, otherwise, use server host - const char *AUTH_HOST = ds->opt_AUTH_HOST ? (const char *)ds->opt_AUTH_HOST + const char *auth_host = ds->opt_AUTH_HOST ? (const char *)ds->opt_AUTH_HOST : (const char *)ds->opt_SERVER; // Go with default region if region is not provided. @@ -138,15 +91,17 @@ bool IAM_PROXY::invoke_func_with_generated_token(std::functionopt_PORT; } - std::string auth_token = this->get_auth_token(AUTH_HOST, region, iam_port, - (const char*)ds->opt_UID, ds->opt_AUTH_EXPIRATION); + std::string auth_token; + bool using_cached_token; + std::tie(auth_token, using_cached_token) = this->auth_util->get_auth_token( + token_cache, token_cache_mutex, auth_host, region, iam_port, ds->opt_UID, ds->opt_AUTH_EXPIRATION); bool connect_result = func(auth_token.c_str()); if (!connect_result) { if (using_cached_token) { // Retry func with a fresh token - auth_token = this->get_auth_token(AUTH_HOST, region, iam_port, (const char*)ds->opt_UID, - ds->opt_AUTH_EXPIRATION, true); + std::tie(auth_token, using_cached_token) = this->auth_util->get_auth_token(token_cache, token_cache_mutex, auth_host, region, iam_port, + ds->opt_UID, ds->opt_AUTH_EXPIRATION, true); if (func(auth_token.c_str())) { return true; } diff --git a/driver/iam_proxy.h b/driver/iam_proxy.h index 0d146f974..25be4437a 100644 --- a/driver/iam_proxy.h +++ b/driver/iam_proxy.h @@ -55,17 +55,10 @@ class IAM_PROXY : public CONNECTION_PROXY { bool change_user(const char* user, const char* passwd, const char* db) override; - - std::string get_auth_token( - const char* host,const char* region, unsigned int port, - const char* user, unsigned int time_until_expiration, - bool force_generate_new_token = false); - protected: static std::unordered_map token_cache; static std::mutex token_cache_mutex; std::shared_ptr auth_util; - bool using_cached_token = false; static void clear_token_cache(); diff --git a/driver/okta_proxy.cc b/driver/okta_proxy.cc index fffe9e437..c84e8a2fc 100644 --- a/driver/okta_proxy.cc +++ b/driver/okta_proxy.cc @@ -38,13 +38,13 @@ std::unordered_map OKTA_PROXY::token_cache; std::mutex OKTA_PROXY::token_cache_mutex; -OKTA_PROXY::OKTA_PROXY(DBC* dbc, DataSource* ds) : OKTA_PROXY(dbc, ds, nullptr) {}; +OKTA_PROXY::OKTA_PROXY(DBC* dbc, DataSource* ds) : OKTA_PROXY(dbc, ds, nullptr){}; OKTA_PROXY::OKTA_PROXY(DBC* dbc, DataSource* ds, CONNECTION_PROXY* next_proxy) : CONNECTION_PROXY(dbc, ds) { this->next_proxy = next_proxy; std::string host{static_cast(ds->opt_IDP_ENDPOINT)}; host += ":" + std::to_string(ds->opt_IDP_PORT); - + const int client_connect_timeout = ds->opt_CLIENT_CONNECT_TIMEOUT; const int client_socket_timeout = ds->opt_CLIENT_SOCKET_TIMEOUT; const bool enable_ssl = ds->opt_ENABLE_SSL; @@ -59,7 +59,8 @@ bool OKTA_PROXY::connect(const char* host, const char* user, const char* passwor } bool OKTA_PROXY::invoke_func_with_fed_credentials(std::function func) { - const char* region = ds->opt_FED_AUTH_REGION ? static_cast(ds->opt_FED_AUTH_REGION) : Aws::Region::US_EAST_1; + const char* region = + ds->opt_FED_AUTH_REGION ? static_cast(ds->opt_FED_AUTH_REGION) : Aws::Region::US_EAST_1; std::string assertion; try { assertion = this->saml_util->get_saml_assertion(ds); @@ -75,21 +76,25 @@ bool OKTA_PROXY::invoke_func_with_fed_credentials(std::functionsaml_util->get_aws_credentials(idp_host, region, iam_role_arn, idp_arn, assertion); this->auth_util = std::make_shared(region, credentials); - const char* AUTH_HOST = - ds->opt_FED_AUTH_HOST ? static_cast(ds->opt_FED_AUTH_HOST) : static_cast(ds->opt_SERVER); + const char* auth_host = ds->opt_FED_AUTH_HOST ? static_cast(ds->opt_FED_AUTH_HOST) + : static_cast(ds->opt_SERVER); int auth_port = ds->opt_FED_AUTH_PORT; if (auth_port == UNDEFINED_PORT) { // Use regular port if user does not provide an alternative port for AWS authentication auth_port = ds->opt_PORT; } - std::string auth_token = this->auth_util->get_auth_token(AUTH_HOST, region, auth_port, ds->opt_UID); + std::string auth_token; + bool using_cached_token; + std::tie(auth_token, using_cached_token) = this->auth_util->get_auth_token( + token_cache, token_cache_mutex, auth_host, region, auth_port, ds->opt_UID, ds->opt_AUTH_EXPIRATION); bool connect_result = func(auth_token.c_str()); if (!connect_result) { if (using_cached_token) { // Retry func with a fresh token - auth_token = this->auth_util->get_auth_token(AUTH_HOST, region, auth_port, ds->opt_UID); + std::tie(auth_token, using_cached_token) = this->auth_util->get_auth_token( + token_cache, token_cache_mutex, auth_host, region, auth_port, ds->opt_UID, ds->opt_AUTH_EXPIRATION, true); if (func(auth_token.c_str())) { return true; } @@ -128,7 +133,8 @@ void OKTA_PROXY::clear_token_cache() { OKTA_SAML_UTIL::OKTA_SAML_UTIL(const std::shared_ptr& client) { this->http_client = client; } OKTA_SAML_UTIL::OKTA_SAML_UTIL(std::string host, int connect_timeout, int socket_timeout, bool enable_ssl) { - this->http_client = std::make_shared("https://" + host, connect_timeout, socket_timeout, enable_ssl); + this->http_client = + std::make_shared("https://" + host, connect_timeout, socket_timeout, enable_ssl); } std::string OKTA_SAML_UTIL::get_saml_url(DataSource* ds) { @@ -145,7 +151,7 @@ std::string OKTA_SAML_UTIL::get_session_token(DataSource* ds) const { const nlohmann::json request_body = {{"username", username}, {"password", password}}; nlohmann::json res; try { - res = this->http_client->post(session_token_endpoint, request_body); + res = this->http_client->post(session_token_endpoint, request_body.dump(), "application/json"); } catch (SAML_HTTP_EXCEPTION& e) { const std::string error = "Failed to get session token from Okta : " + e.error_message() + ". Please verify your Okta credentials."; @@ -163,8 +169,8 @@ std::string OKTA_SAML_UTIL::get_saml_assertion(DataSource* ds) { try { res = this->http_client->get(this->get_saml_url(ds) + "?onetimetoken=" + token); } catch (SAML_HTTP_EXCEPTION& e) { - const std::string error = - "Failed to get SAML assertion from Okta : " + e.error_message() + ". Please verify your Okta identity provider configuration on AWS."; + const std::string error = "Failed to get SAML assertion from Okta : " + e.error_message() + + ". Please verify your Okta identity provider configuration on AWS."; throw SAML_HTTP_EXCEPTION(error); } const auto body = std::string(res); @@ -179,7 +185,7 @@ std::string OKTA_SAML_UTIL::get_saml_assertion(DataSource* ds) { return std::string(); }; - return f(SAML_RESPONSE_PATTERN); + return f(OKTA_REGEX::SAML_RESPONSE_PATTERN); } std::string OKTA_SAML_UTIL::replace_all(std::string str, const std::string& from, const std::string& to) { diff --git a/driver/okta_proxy.h b/driver/okta_proxy.h index 15d9097f4..a4ba1b0c7 100644 --- a/driver/okta_proxy.h +++ b/driver/okta_proxy.h @@ -36,7 +36,7 @@ #include "saml_http_client.h" #include "saml_util.h" -namespace { +namespace OKTA_REGEX { const std::regex SAML_RESPONSE_PATTERN(R"#(name=\"SAMLResponse\".+value=\"(.+)\"/\>)#", std::regex_constants::icase); } diff --git a/driver/saml_http_client.cc b/driver/saml_http_client.cc index 6dbe0eb10..013ca0ab0 100644 --- a/driver/saml_http_client.cc +++ b/driver/saml_http_client.cc @@ -30,6 +30,14 @@ #include "saml_http_client.h" #include +#include "mylog.h" + +#define MAX_REDIRECT_COUNT 20 + +#if !defined(WIN32) +#define stricmp strcasecmp +#endif + SAML_HTTP_CLIENT::SAML_HTTP_CLIENT(std::string host, int connect_timeout, int socket_timeout, bool enable_ssl) : host{std::move(host)}, connect_timeout(connect_timeout), socket_timeout(socket_timeout), enable_ssl(enable_ssl) {} @@ -42,28 +50,61 @@ httplib::Client SAML_HTTP_CLIENT::get_client() const { return client; } -nlohmann::json SAML_HTTP_CLIENT::post(const std::string& path, const nlohmann::json& value) { + +nlohmann::json SAML_HTTP_CLIENT::post(const std::string& path, const std::string& value, + const std::string& content_type) { httplib::Client client = this->get_client(); - if (auto res = client.Post(path.c_str(), value.dump(), "application/json")) { - if (res->status == httplib::StatusCode::OK_200) { - nlohmann::json json_object = nlohmann::json::parse(res->body); - return json_object; + auto res = client.Post(path.c_str(), value, content_type); + if (!res) { + throw SAML_HTTP_EXCEPTION("Post request failed"); + } + if (res->status == httplib::StatusCode::OK_200) { + if (stricmp(content_type.c_str(), "application/json") == 0) { + return nlohmann::json::parse(res->body); } + return res->body; + } - throw SAML_HTTP_EXCEPTION(std::to_string(res->status) + " " + res->reason); + int count = MAX_REDIRECT_COUNT; + while (res->status == httplib::StatusCode::Found_302 && count > 0) { + auto headers = res->headers; + auto pos = headers.find("location"); + if (pos != headers.end()) { + httplib::Headers cookies = {}; + std::string cookiestr; + for (auto const& x : headers) { + if (stricmp(x.first.c_str(), "Set-Cookie") == 0) { + cookiestr += x.second; + cookiestr += ";"; + } + } + cookies.emplace("Cookie", cookiestr); + + httplib::Client redirect_client = this->get_client(); + res = redirect_client.Get(pos->second.c_str(), cookies); + count--; + } + + if (res->status == httplib::StatusCode::OK_200) { + if (stricmp(content_type.c_str(), "application/json") == 0) { + return nlohmann::json::parse(res->body); + } + + return res->body; + } } - throw SAML_HTTP_EXCEPTION("Post request failed"); + throw SAML_HTTP_EXCEPTION(std::to_string(res->status) + " " + res->reason); } -nlohmann::json SAML_HTTP_CLIENT::get(const std::string& path) { +nlohmann::json SAML_HTTP_CLIENT::get(const std::string& path, const httplib::Headers& headers) { httplib::Client client = this->get_client(); client.set_follow_location(true); - if (auto res = client.Get(path.c_str())) { + + if (auto res = (headers.empty() ? client.Get(path) : client.Get(path, headers))) { if (res->status == httplib::StatusCode::OK_200) { return res->body; } throw SAML_HTTP_EXCEPTION(std::to_string(res->status) + " " + res->reason); } - throw SAML_HTTP_EXCEPTION("Get request failed"); } diff --git a/driver/saml_http_client.h b/driver/saml_http_client.h index 6c68eac49..f87b3063f 100644 --- a/driver/saml_http_client.h +++ b/driver/saml_http_client.h @@ -49,8 +49,8 @@ class SAML_HTTP_CLIENT { public: SAML_HTTP_CLIENT(std::string host, int connect_timeout, int socket_timeout, bool enable_ssl); ~SAML_HTTP_CLIENT() = default; - virtual nlohmann::json post(const std::string& path, const nlohmann::json& value); - virtual nlohmann::json get(const std::string& path); + virtual nlohmann::json post(const std::string& path, const std::string& value, const std::string& content_type); + virtual nlohmann::json get(const std::string& path, const httplib::Headers& headers = {}); private: const std::string host; diff --git a/unit_testing/CMakeLists.txt b/unit_testing/CMakeLists.txt index 1f341e01a..a68e3df0f 100644 --- a/unit_testing/CMakeLists.txt +++ b/unit_testing/CMakeLists.txt @@ -55,6 +55,7 @@ add_executable( test_utils.h test_utils.cc + adfs_proxy_test.cc cluster_aware_metrics_test.cc efm_proxy_test.cc iam_proxy_test.cc diff --git a/unit_testing/adfs_proxy_test.cc b/unit_testing/adfs_proxy_test.cc new file mode 100644 index 000000000..38eabd0ab --- /dev/null +++ b/unit_testing/adfs_proxy_test.cc @@ -0,0 +1,118 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// This program is free software; you can redistribute it and/or modify +// it under the terms of the GNU General Public License, version 2.0 +// (GPLv2), as published by the Free Software Foundation, with the +// following additional permissions: +// +// This program is distributed with certain software that is licensed +// under separate terms, as designated in a particular file or component +// or in the license documentation. Without limiting your rights under +// the GPLv2, the authors of this program hereby grant you an additional +// permission to link the program and your derivative works with the +// separately licensed software that they have included with the program. +// +// Without limiting the foregoing grant of rights under the GPLv2 and +// additional permission as to separately licensed software, this +// program is also subject to the Universal FOSS Exception, version 1.0, +// a copy of which can be found along with its FAQ at +// http://oss.oracle.com/licenses/universal-foss-exception. +// +// This program is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +// See the GNU General Public License, version 2.0, for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see +// http://www.gnu.org/licenses/gpl-2.0.html. + +#include +#include +#include + +#include "driver/adfs_proxy.h" +#include "mock_objects.h" +#include "test_utils.h" + +using ::testing::_; +using ::testing::Return; +using ::testing::StrEq; + +namespace { +const std::string TEST_HOST{"test_host"}; +const std::string TEST_USER{"test_user"}; +const std::string TEST_ENDPOINT{"test_endpoint"}; +const std::string TEST_IDP_USERNAME{"test_idp_username"}; +const std::string TEST_IDP_PASSWORD{"test_idp_password"}; +const std::string SIGN_IN_PAGE_URL = "/adfs/ls/IdpInitiatedSignOn.aspx?loginToRp=urn:amazon:webservices"; +const nlohmann::json TEST_SIGN_IN_PAGE = + "
\n\n\n\n"; +const nlohmann::json TEST_SIGN_IN_RESPONSE = + "name=\"SAMLResponse\" " + "value=\"PHNhbWwycDpSZXNwb25zZSBEZXN0aW5hdGlvbj0iaHR0cHM6Ly9zaWduaW4uYXdzLmFtYXpvbi5jb20vc2FtbCI+" + "PC9zYW1sMnA6UmVzcG9uc2U+\"/>"; +const nlohmann::json EXPECTED_ASSERTION = + "PHNhbWwycDpSZXNwb25zZSBEZXN0aW5hdGlvbj0iaHR0cHM6Ly9zaWduaW4uYXdzLmFtYXpvbi5jb20vc2FtbCI+" + "PC9zYW1sMnA6UmVzcG9uc2U+"; +} // namespace + +static SQLHENV env; +static Aws::SDKOptions options; + +class AdfsProxyTest : public testing::Test { + protected: + DBC* dbc; + DataSource* ds; + std::shared_ptr mock_auth_util; + std::shared_ptr mock_saml_http_client; + + static void SetUpTestSuite() { + Aws::InitAPI(options); + SQLAllocHandle(SQL_HANDLE_ENV, nullptr, &env); + } + + static void TearDownTestSuite() { + SQLFreeHandle(SQL_HANDLE_ENV, env); + Aws::ShutdownAPI(options); + } + + void SetUp() override { + SQLHDBC hdbc = nullptr; + SQLAllocHandle(SQL_HANDLE_DBC, env, &hdbc); + dbc = static_cast(hdbc); + ds = new DataSource(); + + ds->opt_AUTH_HOST.set_remove_brackets(to_sqlwchar_string(TEST_HOST).c_str(), TEST_HOST.size()); + ds->opt_UID.set_remove_brackets(to_sqlwchar_string(TEST_USER).c_str(), TEST_USER.size()); + ds->opt_IDP_USERNAME.set_remove_brackets(to_sqlwchar_string(TEST_IDP_USERNAME).c_str(), TEST_IDP_USERNAME.size()); + ds->opt_IDP_PASSWORD.set_remove_brackets(to_sqlwchar_string(TEST_IDP_PASSWORD).c_str(), TEST_IDP_PASSWORD.size()); + ds->opt_IDP_ENDPOINT.set_remove_brackets(to_sqlwchar_string(TEST_ENDPOINT).c_str(), TEST_ENDPOINT.size()); + + mock_saml_http_client = std::make_shared(TEST_ENDPOINT, 10, 10, true); + mock_auth_util = std::make_shared(); + } + + void TearDown() override { cleanup_odbc_handles(nullptr, dbc, ds); } +}; + +TEST_F(AdfsProxyTest, GetSAMLAssertion) { + const httplib::Headers response_body = {{"Set-Cookie", "cookie"}}; + const nlohmann::json expected_cookie = {{"Cookie", "cookie"}}; + const std::string expected_post_body = + "AuthMethod=FormsAuthentication&Password=test_idp_password&UserName=test_idp_username"; + const httplib::Headers header = {}; + + EXPECT_CALL(*mock_saml_http_client, get(StrEq(SIGN_IN_PAGE_URL), header)).WillOnce(Return(TEST_SIGN_IN_PAGE)); + EXPECT_CALL(*mock_saml_http_client, + post(StrEq(SIGN_IN_PAGE_URL), expected_post_body, "application/x-www-form-urlencoded")) + .WillOnce(Return(TEST_SIGN_IN_RESPONSE)); + + ADFS_SAML_UTIL adfs_util(mock_saml_http_client); + + const std::string assertion = adfs_util.get_saml_assertion(ds); + EXPECT_EQ(EXPECTED_ASSERTION, assertion); +} diff --git a/unit_testing/iam_proxy_test.cc b/unit_testing/iam_proxy_test.cc index 2b932750c..84086c4a7 100644 --- a/unit_testing/iam_proxy_test.cc +++ b/unit_testing/iam_proxy_test.cc @@ -54,7 +54,9 @@ class IamProxyTest : public testing::Test { DBC *dbc; DataSource *ds; MOCK_CONNECTION_PROXY *mock_connection_proxy; - std::shared_ptr mock_auth_util; + std::shared_ptr token_test_auth_util; + std::unordered_map token_cache; + std::mutex token_cache_mutex; static void SetUpTestSuite() { Aws::InitAPI(options); @@ -79,10 +81,11 @@ class IamProxyTest : public testing::Test { ds->opt_AUTH_EXPIRATION = TEST_EXPIRATION; mock_connection_proxy = new MOCK_CONNECTION_PROXY(dbc, ds); - mock_auth_util = std::make_shared(); + token_test_auth_util = std::make_shared(); } void TearDown() override { + token_cache.clear(); cleanup_odbc_handles(nullptr, dbc, ds); } }; @@ -94,49 +97,53 @@ TEST_F(IamProxyTest, TokenExpiration) { std::this_thread::sleep_for(std::chrono::seconds(time_to_expire + 1)); EXPECT_TRUE(info.is_expired()); - delete mock_connection_proxy; } TEST_F(IamProxyTest, TokenGetsCachedAndRetrieved) { std::string cache_key = TEST_UTILS::build_cache_key( TEST_HOST.c_str(), TEST_REGION.c_str(), TEST_PORT, TEST_USER.c_str()); - EXPECT_FALSE(TEST_UTILS::token_cache_contains_key(cache_key)); + EXPECT_FALSE(TEST_UTILS::token_cache_contains_key(token_cache, cache_key)); // We should only generate the token once. - EXPECT_CALL(*mock_auth_util, get_auth_token(_, _, _, _)) + EXPECT_CALL(*token_test_auth_util, generate_token(_, _, _, _)) .WillOnce(Return(TEST_TOKEN)); - IAM_PROXY iam_proxy(dbc, ds, mock_connection_proxy, mock_auth_util); - - std::string token1 = iam_proxy.get_auth_token( + std::string token1; + bool use_cached_bool; + std::tie(token1, use_cached_bool) = token_test_auth_util->get_auth_token(token_cache, token_cache_mutex, TEST_HOST.c_str(), TEST_REGION.c_str(), TEST_PORT, TEST_USER.c_str(), 100); - EXPECT_TRUE(TEST_UTILS::token_cache_contains_key(cache_key)); + EXPECT_TRUE(TEST_UTILS::token_cache_contains_key(token_cache, cache_key)); + EXPECT_FALSE(use_cached_bool); // This 2nd call to get_auth_token() will retrieve the cached token. - std::string token2 = iam_proxy.get_auth_token( - TEST_HOST.c_str(), TEST_REGION.c_str(), TEST_PORT, TEST_USER.c_str(), 100); + std::string token2; + std::tie(token2, use_cached_bool) = token_test_auth_util->get_auth_token( + token_cache, token_cache_mutex, TEST_HOST.c_str(), TEST_REGION.c_str(), TEST_PORT, TEST_USER.c_str(), 100); EXPECT_EQ(TEST_TOKEN, token1); EXPECT_TRUE(token1 == token2); - - TEST_UTILS::clear_token_cache(iam_proxy); + EXPECT_TRUE(use_cached_bool); + delete mock_connection_proxy; } TEST_F(IamProxyTest, MultipleCachedTokens) { // Two separate tokens should be generated. - EXPECT_CALL(*mock_auth_util, get_auth_token(_, TEST_REGION.c_str(), TEST_PORT, TEST_USER.c_str())) + EXPECT_CALL(*token_test_auth_util, generate_token(_, TEST_REGION.c_str(), TEST_PORT, TEST_USER.c_str())) .WillOnce(Return(TEST_TOKEN)) .WillOnce(Return(TEST_TOKEN)); - IAM_PROXY iam_proxy(dbc, ds, mock_connection_proxy, mock_auth_util); const char *host2 = "test_host2"; - iam_proxy.get_auth_token( + std::string token1; + bool use_cached_bool; + std::tie(token1, use_cached_bool) = token_test_auth_util->get_auth_token(token_cache, token_cache_mutex, TEST_HOST.c_str(), TEST_REGION.c_str(), TEST_PORT, TEST_USER.c_str(), 100); - iam_proxy.get_auth_token( - host2, TEST_REGION.c_str(), TEST_PORT, TEST_USER.c_str(), 100); + std::tie(token1, use_cached_bool) = token_test_auth_util->get_auth_token( + token_cache, token_cache_mutex, host2, TEST_REGION.c_str(), + TEST_PORT, TEST_USER.c_str(), 100); + std::string cache_key1 = TEST_UTILS::build_cache_key( TEST_HOST.c_str(), TEST_REGION.c_str(), TEST_PORT, TEST_USER.c_str()); @@ -144,59 +151,55 @@ TEST_F(IamProxyTest, MultipleCachedTokens) { host2, TEST_REGION.c_str(), TEST_PORT, TEST_USER.c_str()); EXPECT_NE(cache_key1, cache_key2); - - EXPECT_TRUE(TEST_UTILS::token_cache_contains_key(cache_key1)); - EXPECT_TRUE(TEST_UTILS::token_cache_contains_key(cache_key2)); - - TEST_UTILS::clear_token_cache(iam_proxy); + EXPECT_TRUE(TEST_UTILS::token_cache_contains_key(token_cache, cache_key1)); + EXPECT_TRUE(TEST_UTILS::token_cache_contains_key(token_cache, cache_key2)); + delete mock_connection_proxy; } TEST_F(IamProxyTest, RegenerateTokenAfterExpiration) { // We will generate the token twice because the first token will expire before the 2nd call to get_auth_token(). - EXPECT_CALL(*mock_auth_util, - get_auth_token(TEST_HOST.c_str(), TEST_REGION.c_str(), TEST_PORT, TEST_USER.c_str())) + EXPECT_CALL(*token_test_auth_util, + generate_token(TEST_HOST.c_str(), TEST_REGION.c_str(), TEST_PORT, TEST_USER.c_str())) .WillOnce(Return(TEST_TOKEN)) .WillOnce(Return(TEST_TOKEN)); - IAM_PROXY iam_proxy(dbc, ds, mock_connection_proxy, mock_auth_util); - - const unsigned int time_to_expire = 5; - iam_proxy.get_auth_token( - TEST_HOST.c_str(), TEST_REGION.c_str(), TEST_PORT, TEST_USER.c_str(), time_to_expire); + constexpr unsigned int time_to_expire = 5; + std::string token; + bool use_cached_bool; + std::tie(token, use_cached_bool) = + token_test_auth_util->get_auth_token(token_cache, token_cache_mutex, TEST_HOST.c_str(), TEST_REGION.c_str(), + TEST_PORT, TEST_USER.c_str(), time_to_expire); std::string cache_key = TEST_UTILS::build_cache_key( TEST_HOST.c_str(), TEST_REGION.c_str(), TEST_PORT, TEST_USER.c_str()); - EXPECT_TRUE(TEST_UTILS::token_cache_contains_key(cache_key)); + EXPECT_TRUE(TEST_UTILS::token_cache_contains_key(token_cache, cache_key)); // Wait for first token to expire. std::this_thread::sleep_for(std::chrono::seconds(time_to_expire + 1)); - iam_proxy.get_auth_token( - TEST_HOST.c_str(), TEST_REGION.c_str(), TEST_PORT, TEST_USER.c_str(), time_to_expire); + std::tie(token, use_cached_bool) = + token_test_auth_util->get_auth_token(token_cache, token_cache_mutex, TEST_HOST.c_str(), TEST_REGION.c_str(), + TEST_PORT, TEST_USER.c_str(), time_to_expire); - EXPECT_TRUE(TEST_UTILS::token_cache_contains_key(cache_key)); - - TEST_UTILS::clear_token_cache(iam_proxy); + EXPECT_TRUE(TEST_UTILS::token_cache_contains_key(token_cache, cache_key)); + delete mock_connection_proxy; } TEST_F(IamProxyTest, ForceGenerateNewToken) { // We expect a token to be generated twice because the 2nd call to get_auth_token forces a fresh token. - EXPECT_CALL(*mock_auth_util, - get_auth_token(TEST_HOST.c_str(), TEST_REGION.c_str(), TEST_PORT, TEST_USER.c_str())) + EXPECT_CALL(*token_test_auth_util, + generate_token(TEST_HOST.c_str(), TEST_REGION.c_str(), TEST_PORT, TEST_USER.c_str())) .WillOnce(Return(TEST_TOKEN)) .WillOnce(Return(TEST_TOKEN)); - IAM_PROXY iam_proxy(dbc, ds, mock_connection_proxy, mock_auth_util); - - const unsigned int time_to_expire = 100; - iam_proxy.get_auth_token( + constexpr unsigned int time_to_expire = 100; + token_test_auth_util->get_auth_token(token_cache, token_cache_mutex, TEST_HOST.c_str(), TEST_REGION.c_str(), TEST_PORT, TEST_USER.c_str(), time_to_expire); // 2nd call to get_auth_token should still generate a new token because we are forcing it - // even though the first token has not yet expired - iam_proxy.get_auth_token( + // even though the first token has not yet expired + token_test_auth_util->get_auth_token(token_cache, token_cache_mutex, TEST_HOST.c_str(), TEST_REGION.c_str(), TEST_PORT, TEST_USER.c_str(), time_to_expire, true); - - TEST_UTILS::clear_token_cache(iam_proxy); + delete mock_connection_proxy; } TEST_F(IamProxyTest, RetryConnectionWithFreshTokenAfterFailingWithCachedToken) { @@ -209,11 +212,12 @@ TEST_F(IamProxyTest, RetryConnectionWithFreshTokenAfterFailingWithCachedToken) { .WillOnce(Return(true)); // Only called twice because one of the above connection attempts used a cached token. - EXPECT_CALL(*mock_auth_util, get_auth_token(_, _, _, _)) + EXPECT_CALL(*token_test_auth_util, + generate_token(_, _, _, _)) .WillOnce(Return(TEST_TOKEN)) .WillOnce(Return(TEST_TOKEN)); - IAM_PROXY iam_proxy(dbc, ds, mock_connection_proxy, mock_auth_util); + IAM_PROXY iam_proxy(dbc, ds, mock_connection_proxy, token_test_auth_util); // First successful connection to get a token cached. bool ret = iam_proxy.connect(TEST_HOST.c_str(), TEST_USER.c_str(), "", "", TEST_PORT, "", 0); @@ -223,20 +227,17 @@ TEST_F(IamProxyTest, RetryConnectionWithFreshTokenAfterFailingWithCachedToken) { // After failing that attempt it will try again with a fresh token and succeed. ret = iam_proxy.connect(TEST_HOST.c_str(), TEST_USER.c_str(), "", "", TEST_PORT, "", 0); EXPECT_TRUE(ret); - - TEST_UTILS::clear_token_cache(iam_proxy); } TEST_F(IamProxyTest, UseRegularPortWhenAuthPortIsNotProvided) { ds->opt_AUTH_PORT = UNDEFINED_PORT; // Verify that we generate the token with the regular port when we do not have auth port. - EXPECT_CALL(*mock_auth_util, get_auth_token(_, _, TEST_PORT, _)) + EXPECT_CALL(*token_test_auth_util, + generate_token(_, _, TEST_PORT, _)) .WillOnce(Return(TEST_TOKEN)); - IAM_PROXY iam_proxy(dbc, ds, mock_connection_proxy, mock_auth_util); + IAM_PROXY iam_proxy(dbc, ds, mock_connection_proxy, token_test_auth_util); iam_proxy.connect(TEST_HOST.c_str(), TEST_USER.c_str(), "", "", TEST_PORT, "", 0); - - TEST_UTILS::clear_token_cache(iam_proxy); } diff --git a/unit_testing/mock_objects.h b/unit_testing/mock_objects.h index 2185c2856..80ebce772 100644 --- a/unit_testing/mock_objects.h +++ b/unit_testing/mock_objects.h @@ -31,11 +31,11 @@ #define __MOCKOBJECTS_H__ #include +#include #include #include "driver/connection_proxy.h" #include "driver/failover.h" -#include "driver/iam_proxy.h" #include "driver/saml_http_client.h" #include "driver/monitor_thread_container.h" #include "driver/monitor_service.h" @@ -225,14 +225,13 @@ class MOCK_SECRETS_MANAGER_CLIENT : public Aws::SecretsManager::SecretsManagerCl class MOCK_AUTH_UTIL : public AUTH_UTIL { public: MOCK_AUTH_UTIL() : AUTH_UTIL() {}; - - MOCK_METHOD(std::string, get_auth_token, (const char*, const char*, unsigned int, const char*)); + MOCK_METHOD(std::string, generate_token, (const char*, const char*, unsigned int, const char*)); }; class MOCK_SAML_HTTP_CLIENT : public SAML_HTTP_CLIENT { public: MOCK_SAML_HTTP_CLIENT(std::string host, int connect_timeout, int socket_timeout, bool enable_ssl) : SAML_HTTP_CLIENT(host, connect_timeout, socket_timeout, enable_ssl) {}; - MOCK_METHOD(nlohmann::json, post, (const std::string&, const nlohmann::json&)); - MOCK_METHOD(nlohmann::json, get, (const std::string&)); + MOCK_METHOD(nlohmann::json, post, (const std::string&, const std::string&, const std::string&)); + MOCK_METHOD(nlohmann::json, get, (const std::string&, const httplib::Headers&)); }; #endif /* __MOCKOBJECTS_H__ */ diff --git a/unit_testing/okta_proxy_test.cc b/unit_testing/okta_proxy_test.cc index 0f1051eb5..a5f2862be 100644 --- a/unit_testing/okta_proxy_test.cc +++ b/unit_testing/okta_proxy_test.cc @@ -114,7 +114,8 @@ TEST_F(OktaProxyTest, GetSAMLURL) { TEST_F(OktaProxyTest, GetSessionToken) { const nlohmann::json request_body = {{"username", "test_idp_username"}, {"password", "test_idp_password"}}; - EXPECT_CALL(*mock_saml_http_client, post(StrEq("/api/v1/authn"), request_body)).WillOnce(Return(TEST_SESSION_TOKEN)); + EXPECT_CALL(*mock_saml_http_client, post(StrEq("/api/v1/authn"), request_body.dump(), "application/json")) + .WillOnce(Return(TEST_SESSION_TOKEN)); OKTA_SAML_UTIL okta_util(mock_saml_http_client); const std::string token = okta_util.get_session_token(ds); @@ -126,8 +127,8 @@ TEST_F(OktaProxyTest, GetSAMLAssertion) { "/app/amazon_aws/test_app/sso/saml?onetimetoken=20111sTEtWA8_kJzLH-JQ87ScdVRZOa6NcaX9-letters"; const nlohmann::json request_body = {{"username", "test_idp_username"}, {"password", "test_idp_password"}}; - EXPECT_CALL(*mock_saml_http_client, post(StrEq("/api/v1/authn"), request_body)).WillOnce(Return(TEST_SESSION_TOKEN)); - EXPECT_CALL(*mock_saml_http_client, get(_)).WillOnce(Return(TEST_ASSERTION)); + EXPECT_CALL(*mock_saml_http_client, post(StrEq("/api/v1/authn"), request_body.dump(), "application/json")).WillOnce(Return(TEST_SESSION_TOKEN)); + EXPECT_CALL(*mock_saml_http_client, get(_, _)).WillOnce(Return(TEST_ASSERTION)); OKTA_SAML_UTIL okta_util(mock_saml_http_client); diff --git a/unit_testing/test_utils.cc b/unit_testing/test_utils.cc index 14ac319ac..6d0709d8b 100644 --- a/unit_testing/test_utils.cc +++ b/unit_testing/test_utils.cc @@ -115,12 +115,8 @@ std::string TEST_UTILS::build_cache_key(const char* host, const char* region, un return AUTH_UTIL::build_cache_key(host, region, port, user); } -bool TEST_UTILS::token_cache_contains_key(std::string cache_key) { - return IAM_PROXY::token_cache.find(cache_key) != IAM_PROXY::token_cache.end(); -} - -void TEST_UTILS::clear_token_cache(IAM_PROXY& iam_proxy) { - iam_proxy.clear_token_cache(); +bool TEST_UTILS::token_cache_contains_key(std::unordered_map token_cache, std::string cache_key) { + return token_cache.find(cache_key) != token_cache.end(); } std::map, Aws::Utils::Json::JsonValue>& TEST_UTILS::get_secrets_cache() { diff --git a/unit_testing/test_utils.h b/unit_testing/test_utils.h index 292447af7..a2766633b 100644 --- a/unit_testing/test_utils.h +++ b/unit_testing/test_utils.h @@ -58,8 +58,7 @@ class TEST_UTILS { static size_t get_map_size(std::shared_ptr container); static std::list> get_contexts(std::shared_ptr monitor); static std::string build_cache_key(const char* host, const char* region, unsigned int port, const char* user); - static bool token_cache_contains_key(std::string cache_key); - static void clear_token_cache(IAM_PROXY& iam_proxy); + static bool token_cache_contains_key(std::unordered_map token_cache, std::string cache_key); static std::map, Aws::Utils::Json::JsonValue>& get_secrets_cache(); static bool try_parse_region_from_secret(std::string secret, std::string& region); static bool is_dns_pattern_valid(std::string host);