Skip to content

Commit

Permalink
HCK-8622: refactored connection code and encahnced username, host and…
Browse files Browse the repository at this point in the history
… authmethod logs
  • Loading branch information
WilhelmWesser committed Nov 1, 2024
1 parent b1df057 commit cf30fab
Show file tree
Hide file tree
Showing 2 changed files with 283 additions and 161 deletions.
181 changes: 20 additions & 161 deletions reverse_engineering/databaseService/databaseService.js
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ const { getObjectsFromDatabase, getNewConnectionClientByDb } = require('./helper
const msal = require('@azure/msal-node');
const getSampleDocSize = require('../helpers/getSampleDocSize');
const { logAuthTokenInfo } = require('../helpers/logInfo');
const { getConnection } = require('./helpers/connection');

const QUERY_REQUEST_TIMEOUT = 60000;

Expand All @@ -17,12 +18,6 @@ const getConnectionClient = async (connectionInfo, logger) => {
const tenantId = connectionInfo.connectionTenantId || connectionInfo.tenantId || 'common';
const queryRequestTimeout = Number(connectionInfo.queryRequestTimeout) || QUERY_REQUEST_TIMEOUT;

logger.log(
'info',
`hostname: ${hostName}, username: ${userName}, auth method: ${connectionInfo.authMethod}`,
'Auth info',
);

const commonConfig = {
server: connectionInfo.host,
port: +connectionInfo.port,
Expand All @@ -38,63 +33,20 @@ const getConnectionClient = async (connectionInfo, logger) => {
const clientId = '0dc36597-bc44-49f8-a4a7-ae5401959b85';
const redirectUri = 'http://localhost:8080';

switch (connectionInfo.authMethod) {
case 'Username / Password':
return sql.connect({
...commonConfig,
...credentialsConfig,
options: {
encrypt: true,
enableArithAbort: true,
},
});
case 'Username / Password (Windows)':
return sql.connect({
...commonConfig,
...credentialsConfig,
domain: connectionInfo.userDomain,
options: {
encrypt: false,
enableArithAbort: true,
},
});
case 'Azure Active Directory (MFA)':
const token = await getToken({ connectionInfo, tenantId, clientId, redirectUri, logger });
logAuthTokenInfo({ token, logger });
return sql.connect({
...commonConfig,
options: {
encrypt: true,
enableArithAbort: true,
},
authentication: {
type: 'azure-active-directory-access-token',
options: {
token,
},
},
});
case 'Azure Active Directory (Username / Password)':
return sql.connect({
...commonConfig,
...credentialsConfig,
options: {
encrypt: true,
enableArithAbort: true,
},
authentication: {
type: 'azure-active-directory-password',
options: {
userName: connectionInfo.userName,
password: connectionInfo.userPassword,
tenantId,
clientId,
},
},
});
}
const connection = getConnection({
type: connectionInfo.authMethod,
data: {
connectionInfo,
commonConfig,
credentialsConfig,
tenantId,
clientId,
redirectUri,
logger,
},
});

return await sql.connect(connectionInfo.connectionString);
return connection.connect();
};

const isEmail = name => /^[^\s@]+@[^\s@]+\.[^\s@]+$/.test(name || '');
Expand Down Expand Up @@ -513,98 +465,13 @@ const mapResponse = async (response = {}) => {
return (await response).recordset;
};

const getTokenByMSAL = async ({ connectionInfo, redirectUri, clientId, tenantId, logger }) => {
try {
const pca = new msal.PublicClientApplication(getAuthConfig(clientId, tenantId, logger.log));
const tokenRequest = {
code: connectionInfo?.externalBrowserQuery?.code || '',
scopes: ['https://database.windows.net//.default'],
redirectUri,
codeVerifier: connectionInfo?.proofKey,
clientInfo: connectionInfo?.externalBrowserQuery?.client_info || '',
};

const responseData = await pca.acquireTokenByCode(tokenRequest);

return responseData.accessToken;
} catch (error) {
logger.log('error', { message: error.message, stack: error.stack, error }, 'MFA MSAL auth error');
return '';
}
};

const getAgent = (reject, cert, key) => {
return new https.Agent({ cert, key, rejectUnauthorized: !!reject });
};

const getTokenByAxios = async ({ connectionInfo, tenantId, redirectUri, clientId, logger, agent }) => {
try {
const params = new URLSearchParams();
params.append('code', connectionInfo?.externalBrowserQuery?.code || '');
params.append('client_id', clientId);
params.append('redirect_uri', redirectUri);
params.append('grant_type', 'authorization_code');
params.append('code_verifier', connectionInfo?.proofKey);
params.append('resource', 'https://database.windows.net/');

const responseData = await axios.post(`https://login.microsoftonline.com/${tenantId}/oauth2/token`, params, {
headers: {
'Content-Type': 'application/x-www-form-urlencoded',
},
...(agent && { httpsAgent: agent }),
});

return responseData?.data?.access_token || '';
} catch (error) {
logger.log('error', { message: error.message, stack: error.stack, error }, 'MFA Axios auth error');
return '';
}
};

const getTokenByAxiosExtended = params => {
return getTokenByAxios({ ...params, agent: getAgent() });
};

const getToken = async ({ connectionInfo, tenantId, clientId, redirectUri, logger }) => {
const axiosExtendedToken = await getTokenByAxiosExtended({
connectionInfo,
clientId,
redirectUri,
tenantId,
logger,
});
if (axiosExtendedToken) {
return axiosExtendedToken;
}

const msalToken = await getTokenByMSAL({ connectionInfo, clientId, redirectUri, tenantId, logger });
if (msalToken) {
return msalToken;
}

const axiosToken = await getTokenByAxios({ connectionInfo, clientId, redirectUri, tenantId, logger });
if (axiosToken) {
return axiosToken;
}

return;
};
async function getTableRowCount(tableSchema, tableName, currentDbConnectionClient) {
const rowCountQuery = `SELECT COUNT(*) as rowsCount FROM [${tableSchema}].[${tableName}]`;
const rowCountResponse = await currentDbConnectionClient.query(rowCountQuery);
const rowCount = rowCountResponse?.recordset[0]?.rowsCount;

const getAuthConfig = (clientId, tenantId, logger) => ({
system: {
loggerOptions: {
loggerCallback(loglevel, message) {
logger(message);
},
piiLoggingEnabled: false,
logLevel: msal.LogLevel.Verbose,
},
},
auth: {
clientId,
authority: `https://login.microsoftonline.com/${tenantId}`,
},
});
return rowCount;
}

module.exports = {
getConnectionClient,
Expand All @@ -629,11 +496,3 @@ module.exports = {
queryDistribution,
getPartitions,
};

async function getTableRowCount(tableSchema, tableName, currentDbConnectionClient) {
const rowCountQuery = `SELECT COUNT(*) as rowsCount FROM [${tableSchema}].[${tableName}]`;
const rowCountResponse = await currentDbConnectionClient.query(rowCountQuery);
const rowCount = rowCountResponse?.recordset[0]?.rowsCount;

return rowCount;
}
Loading

0 comments on commit cf30fab

Please sign in to comment.