Skip to content

Commit

Permalink
Merge pull request #516 from v-xiangs/v-xiangs-6.2.2
Browse files Browse the repository at this point in the history
Update ADAL4J and AKV
  • Loading branch information
xiangyushawn authored Sep 29, 2017
2 parents e3857f5 + f1ab1f9 commit ea83c59
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 128 deletions.
4 changes: 2 additions & 2 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,14 @@
<dependency>
<groupId>com.microsoft.azure</groupId>
<artifactId>azure-keyvault</artifactId>
<version>0.9.7</version>
<version>1.0.0</version>
<optional>true</optional>
</dependency>

<dependency>
<groupId>com.microsoft.azure</groupId>
<artifactId>adal4j</artifactId>
<version>1.1.3</version>
<version>1.2.0</version>
<optional>true</optional>
</dependency>

Expand Down
76 changes: 40 additions & 36 deletions src/main/java/com/microsoft/sqlserver/jdbc/KeyVaultCredential.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,57 +8,61 @@

package com.microsoft.sqlserver.jdbc;

import java.util.Map;

import org.apache.http.Header;
import org.apache.http.message.BasicHeader;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

import com.microsoft.aad.adal4j.AuthenticationContext;
import com.microsoft.aad.adal4j.AuthenticationResult;
import com.microsoft.aad.adal4j.ClientCredential;
import com.microsoft.azure.keyvault.authentication.KeyVaultCredentials;
import com.microsoft.windowsazure.core.pipeline.filter.ServiceRequestContext;

/**
*
* An implementation of ServiceClientCredentials that supports automatic bearer token refresh.
*
*/
class KeyVaultCredential extends KeyVaultCredentials {

// this is the only supported access token type
// https://msdn.microsoft.com/en-us/library/azure/dn645538.aspx
private final String accessTokenType = "Bearer";

SQLServerKeyVaultAuthenticationCallback authenticationCallback = null;
String clientId = null;
String clientKey = null;
String accessToken = null;

KeyVaultCredential(SQLServerKeyVaultAuthenticationCallback authenticationCallback) {
this.authenticationCallback = authenticationCallback;
KeyVaultCredential(String clientId,
String clientKey) {
this.clientId = clientId;
this.clientKey = clientKey;
}

/**
* Authenticates the service request
*
* @param request
* the ServiceRequestContext
* @param challenge
* used to get the accessToken
* @return BasicHeader
*/
@Override
public Header doAuthenticate(ServiceRequestContext request,
Map<String, String> challenge) {
assert null != challenge;

String authorization = challenge.get("authorization");
String resource = challenge.get("resource");

accessToken = authenticationCallback.getAccessToken(authorization, resource, "");
return new BasicHeader("Authorization", accessTokenType + " " + accessToken);
public String doAuthenticate(String authorization,
String resource,
String scope) {
AuthenticationResult token = getAccessTokenFromClientCredentials(authorization, resource, clientId, clientKey);
return token.getAccessToken();
}

void setAccessToken(String accessToken) {
this.accessToken = accessToken;
}
private static AuthenticationResult getAccessTokenFromClientCredentials(String authorization,
String resource,
String clientId,
String clientKey) {
AuthenticationContext context = null;
AuthenticationResult result = null;
ExecutorService service = null;
try {
service = Executors.newFixedThreadPool(1);
context = new AuthenticationContext(authorization, false, service);
ClientCredential credentials = new ClientCredential(clientId, clientKey);
Future<AuthenticationResult> future = context.acquireToken(resource, credentials, null);
result = future.get();
}
catch (Exception e) {
throw new RuntimeException(e);
}
finally {
service.shutdown();
}

if (result == null) {
throw new RuntimeException("authentication result was null");
}
return result;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,12 @@
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.text.MessageFormat;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;

import org.apache.http.impl.client.HttpClientBuilder;

import com.microsoft.azure.keyvault.KeyVaultClient;
import com.microsoft.azure.keyvault.KeyVaultClientImpl;
import com.microsoft.azure.keyvault.models.KeyBundle;
import com.microsoft.azure.keyvault.models.KeyOperationResult;
import com.microsoft.azure.keyvault.models.KeyVerifyResult;
import com.microsoft.azure.keyvault.webkey.JsonWebKeyEncryptionAlgorithm;
import com.microsoft.azure.keyvault.webkey.JsonWebKeySignatureAlgorithm;

/**
Expand Down Expand Up @@ -66,26 +63,19 @@ public String getName() {
}

/**
* Constructor that takes a callback function to authenticate to AAD. This is used by KeyVaultClient at runtime to authenticate to Azure Key
* Vault.
* Constructor that authenticates to AAD. This is used by KeyVaultClient at runtime to authenticate to Azure Key
*
* @param authenticationCallback
* - Callback function used for authenticating to AAD.
* @param executorService
* - The ExecutorService used to create the keyVaultClient
* @param clientId
* Identifier of the client requesting the token.
* @param clientKey
* Key of the client requesting the token.
* @throws SQLServerException
* when an error occurs
*/
public SQLServerColumnEncryptionAzureKeyVaultProvider(SQLServerKeyVaultAuthenticationCallback authenticationCallback,
ExecutorService executorService) throws SQLServerException {
if (null == authenticationCallback) {
MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_NullValue"));
Object[] msgArgs1 = {"SQLServerKeyVaultAuthenticationCallback"};
throw new SQLServerException(form.format(msgArgs1), null);
}
credential = new KeyVaultCredential(authenticationCallback);
HttpClientBuilder builder = HttpClientBuilder.create();
keyVaultClient = new KeyVaultClientImpl(builder, executorService, credential);
public SQLServerColumnEncryptionAzureKeyVaultProvider(String clientId,
String clientKey) throws SQLServerException {
credential = new KeyVaultCredential(clientId, clientKey);
keyVaultClient = new KeyVaultClient(credential);
}

/**
Expand Down Expand Up @@ -308,7 +298,7 @@ public byte[] encryptColumnEncryptionKey(String masterKeyPath,
byte dataToSign[] = md.digest();

// Sign the hash
byte[] signedHash = AzureKeyVaultSignHashedData(dataToSign, masterKeyPath);
byte[] signedHash = AzureKeyVaultSignHashedData(dataToSign, masterKeyPath);

if (signedHash.length != keySizeInBytes) {
throw new SQLServerException(SQLServerException.getErrString("R_SignedHashLengthError"), null);
Expand Down Expand Up @@ -433,14 +423,10 @@ private byte[] AzureKeyVaultWrap(String masterKeyPath,
throw new SQLServerException(SQLServerException.getErrString("R_CEKNull"), null);
}

KeyOperationResult wrappedKey = null;
try {
wrappedKey = keyVaultClient.wrapKeyAsync(masterKeyPath, encryptionAlgorithm, columnEncryptionKey).get();
}
catch (InterruptedException | ExecutionException e) {
throw new SQLServerException(SQLServerException.getErrString("R_EncryptCEKError"), e);
}
return wrappedKey.getResult();
JsonWebKeyEncryptionAlgorithm jsonEncryptionAlgorithm = new JsonWebKeyEncryptionAlgorithm(encryptionAlgorithm);
KeyOperationResult wrappedKey = keyVaultClient.wrapKey(masterKeyPath, jsonEncryptionAlgorithm, columnEncryptionKey);

return wrappedKey.result();
}

/**
Expand All @@ -466,14 +452,10 @@ private byte[] AzureKeyVaultUnWrap(String masterKeyPath,
throw new SQLServerException(SQLServerException.getErrString("R_EmptyEncryptedCEK"), null);
}

KeyOperationResult unwrappedKey;
try {
unwrappedKey = keyVaultClient.unwrapKeyAsync(masterKeyPath, encryptionAlgorithm, encryptedColumnEncryptionKey).get();
}
catch (InterruptedException | ExecutionException e) {
throw new SQLServerException(SQLServerException.getErrString("R_DecryptCEKError"), e);
}
return unwrappedKey.getResult();
JsonWebKeyEncryptionAlgorithm jsonEncryptionAlgorithm = new JsonWebKeyEncryptionAlgorithm(encryptionAlgorithm);
KeyOperationResult unwrappedKey = keyVaultClient.unwrapKey(masterKeyPath, jsonEncryptionAlgorithm, encryptedColumnEncryptionKey);

return unwrappedKey.result();
}

/**
Expand All @@ -490,14 +472,9 @@ private byte[] AzureKeyVaultSignHashedData(byte[] dataToSign,
String masterKeyPath) throws SQLServerException {
assert ((null != dataToSign) && (0 != dataToSign.length));

KeyOperationResult signedData = null;
try {
signedData = keyVaultClient.signAsync(masterKeyPath, JsonWebKeySignatureAlgorithm.RS256, dataToSign).get();
}
catch (InterruptedException | ExecutionException e) {
throw new SQLServerException(SQLServerException.getErrString("R_GenerateSignature"), e);
}
return signedData.getResult();
KeyOperationResult signedData = keyVaultClient.sign(masterKeyPath, JsonWebKeySignatureAlgorithm.RS256, dataToSign);

return signedData.result();
}

/**
Expand All @@ -516,15 +493,9 @@ private boolean AzureKeyVaultVerifySignature(byte[] dataToVerify,
assert ((null != dataToVerify) && (0 != dataToVerify.length));
assert ((null != signature) && (0 != signature.length));

boolean valid = false;
try {
valid = keyVaultClient.verifyAsync(masterKeyPath, JsonWebKeySignatureAlgorithm.RS256, dataToVerify, signature).get();
}
catch (InterruptedException | ExecutionException e) {
throw new SQLServerException(SQLServerException.getErrString("R_VerifySignature"), e);
}
KeyVerifyResult valid = keyVaultClient.verify(masterKeyPath, JsonWebKeySignatureAlgorithm.RS256, dataToVerify, signature);

return valid;
return valid.value();
}

/**
Expand All @@ -537,21 +508,22 @@ private boolean AzureKeyVaultVerifySignature(byte[] dataToVerify,
* when an error occurs
*/
private int getAKVKeySize(String masterKeyPath) throws SQLServerException {
KeyBundle retrievedKey = keyVaultClient.getKey(masterKeyPath);

KeyBundle retrievedKey = null;
try {
retrievedKey = keyVaultClient.getKeyAsync(masterKeyPath).get();
}
catch (InterruptedException | ExecutionException e) {
throw new SQLServerException(SQLServerException.getErrString("R_GetAKVKeySize"), e);
if (null == retrievedKey) {
String[] keyTokens = masterKeyPath.split("/");

MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_AKVKeyNotFound"));
Object[] msgArgs = {keyTokens[keyTokens.length - 1]};
throw new SQLServerException(null, form.format(msgArgs), null, 0, false);
}

if (!"RSA".equalsIgnoreCase(retrievedKey.getKey().getKty()) && !"RSA-HSM".equalsIgnoreCase(retrievedKey.getKey().getKty())) {
if (!"RSA".equalsIgnoreCase(retrievedKey.key().kty().toString()) && !"RSA-HSM".equalsIgnoreCase(retrievedKey.key().kty().toString())) {
MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_NonRSAKey"));
Object[] msgArgs = {retrievedKey.getKey().getKty()};
Object[] msgArgs = {retrievedKey.key().kty().toString()};
throw new SQLServerException(null, form.format(msgArgs), null, 0, false);
}

return retrievedKey.getKey().getN().length;
return retrievedKey.key().n().length;
}
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -383,5 +383,6 @@ protected Object[][] getContents() {
{"R_kerberosLoginFailed", "Kerberos Login failed: {0} due to {1} ({2})"},
{"R_StoredProcedureNotFound", "Could not find stored procedure ''{0}''."},
{"R_jaasConfigurationNamePropertyDescription", "Login configuration file for Kerberos authentication."},
{"R_AKVKeyNotFound", "Key not found: {0}"},
};
}

0 comments on commit ea83c59

Please sign in to comment.