Skip to content

Commit

Permalink
fix(plc4j): Adjust encoding/decoding logic to support mixed key lengths.
Browse files Browse the repository at this point in the history
As found during the tests of MX OPC server, our encoding logic did not work properly when client and server used different key lengths.
Changes introduced in this commit address this issue, and add few additional tests to confirm valid behavior.

Relates to #1682 and #1802, whereas second issue lead to discovery of the bug.

Signed-off-by: Łukasz Dywicki <luke@code-house.org>
  • Loading branch information
splatch committed Oct 17, 2024
1 parent 88709be commit cb8153a
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 53 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ protected int decrypt(WriteBufferByteBased chunkBuffer, Chunk chunk, int message
for (int block = 0; block < blockCount; block++) {
int pos = block * chunk.getCipherTextBlockSize();

bodyLength += cipher.doFinal(encrypted, pos, chunk.getCipherTextBlockSize(), plainText, pos);
bodyLength += cipher.doFinal(encrypted, pos, chunk.getCipherTextBlockSize(), plainText, bodyLength);
}

chunkBuffer.setPos(bodyStart);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -223,9 +223,10 @@ private short getPaddingSize(WriteBufferByteBased chunkBuffer, Chunk chunk, int
int paddingEnd = messageLength - chunk.getSignatureSize() - encryptionOverhead - chunk.getPaddingOverhead();
byte[] padding = chunkBuffer.getBytes(paddingEnd, paddingEnd + chunk.getPaddingOverhead());
if (padding.length > 2) { // cipher block size exceeds 256 bytes
return (short)(((padding[1] & 0xFF) << 8) | (padding[0] & 0xFF));
int paddingSize = ((padding[1] & 0xFF) << 8) | (padding[0] & 0xFF);
return (short) (paddingSize & 0xFFFF);
}
return padding[0];
return (short) (padding[0] & 0xFF);
}

protected abstract void verify(WriteBufferByteBased buffer, Chunk chunk, int messageLength) throws Exception;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ public Chunk create(boolean asymmetric, boolean encrypted, boolean signed, Secur
int serverCertificateThumbprint = asymmetric ? certificateThumbprint(remoteCertificate).length : 0;

int asymmetricSecurityHarderSize = (12 + securityPolicy.getSecurityPolicyUri().length() + localCertificateSize + serverCertificateThumbprint);
int asymmetricCipherTextBlockSize = asymmetric ? (localAsymmetricKeyLength + 7) / 8 : 0;
int plainTextTextBlockSize = asymmetric ? (localAsymmetricKeyLength + 7) / 8 : 0;
int asymmetricCipherTextBlockSize = asymmetric ? (remoteAsymmetricKeyLength + 7) / 8 : 0;
int plainTextBlockSize = asymmetric ? (remoteAsymmetricKeyLength + 7) / 8 : 0;

int cipherTextBlockSize = asymmetric ? asymmetricCipherTextBlockSize : (encrypted ? securityPolicy.getEncryptionBlockSize() : 1);

Expand All @@ -81,8 +81,8 @@ public Chunk create(boolean asymmetric, boolean encrypted, boolean signed, Secur
return new Chunk(
asymmetric ? asymmetricSecurityHarderSize : SYMMETRIC_SECURITY_HEADER_SIZE,
cipherTextBlockSize,
asymmetric ? plainTextTextBlockSize - 11 : (encrypted ? securityPolicy.getEncryptionBlockSize() : 1),
asymmetric ? ((remoteAsymmetricKeyLength + 7) / 8) : securityPolicy.getSymmetricSignatureSize(),
asymmetric ? plainTextBlockSize - 11 : (encrypted ? securityPolicy.getEncryptionBlockSize() : 1),
asymmetric ? ((localAsymmetricKeyLength + 7) / 8) : securityPolicy.getSymmetricSignatureSize(),
(int) limits.getSendBufferSize(),
asymmetric,
encryption,
Expand All @@ -93,8 +93,8 @@ public Chunk create(boolean asymmetric, boolean encrypted, boolean signed, Secur
// 12 + 56 + 674 + 20
asymmetric ? asymmetricSecurityHarderSize : SYMMETRIC_SECURITY_HEADER_SIZE,
cipherTextBlockSize,
asymmetric ? plainTextTextBlockSize - 42 : (encrypted ? securityPolicy.getEncryptionBlockSize() : 1),
asymmetric ? ((remoteAsymmetricKeyLength + 7) / 8) : securityPolicy.getSymmetricSignatureSize(),
asymmetric ? plainTextBlockSize - 42 : (encrypted ? securityPolicy.getEncryptionBlockSize() : 1),
asymmetric ? ((localAsymmetricKeyLength + 7) / 8) : securityPolicy.getSymmetricSignatureSize(),
(int) limits.getSendBufferSize(),
asymmetric,
encryption,
Expand All @@ -104,8 +104,8 @@ public Chunk create(boolean asymmetric, boolean encrypted, boolean signed, Secur
return new Chunk(
asymmetric ? asymmetricSecurityHarderSize : SYMMETRIC_SECURITY_HEADER_SIZE,
cipherTextBlockSize,
asymmetric ? plainTextTextBlockSize - 42 : (encrypted ? securityPolicy.getEncryptionBlockSize() : 1),
asymmetric ? ((remoteAsymmetricKeyLength + 7) / 8) : securityPolicy.getSymmetricSignatureSize(),
asymmetric ? plainTextBlockSize - 42 : (encrypted ? securityPolicy.getEncryptionBlockSize() : 1),
asymmetric ? ((localAsymmetricKeyLength + 7) / 8) : securityPolicy.getSymmetricSignatureSize(),
(int) limits.getSendBufferSize(),
asymmetric,
encryption,
Expand All @@ -115,8 +115,8 @@ public Chunk create(boolean asymmetric, boolean encrypted, boolean signed, Secur
return new Chunk(
asymmetric ? asymmetricSecurityHarderSize : SYMMETRIC_SECURITY_HEADER_SIZE,
cipherTextBlockSize,
asymmetric ? plainTextTextBlockSize - 42 : (encrypted ? securityPolicy.getEncryptionBlockSize() : 1),
asymmetric ? ((remoteAsymmetricKeyLength + 7) / 8) : securityPolicy.getSymmetricSignatureSize(),
asymmetric ? plainTextBlockSize - 42 : (encrypted ? securityPolicy.getEncryptionBlockSize() : 1),
asymmetric ? ((localAsymmetricKeyLength + 7) / 8) : securityPolicy.getSymmetricSignatureSize(),
(int) limits.getSendBufferSize(),
asymmetric,
encryption,
Expand All @@ -126,8 +126,8 @@ public Chunk create(boolean asymmetric, boolean encrypted, boolean signed, Secur
return new Chunk(
asymmetric ? asymmetricSecurityHarderSize : SYMMETRIC_SECURITY_HEADER_SIZE,
cipherTextBlockSize,
asymmetric ? plainTextTextBlockSize - 66 : (encrypted ? securityPolicy.getEncryptionBlockSize() : 1),
asymmetric ? ((remoteAsymmetricKeyLength + 7) / 8) : securityPolicy.getSymmetricSignatureSize(),
asymmetric ? plainTextBlockSize - 66 : (encrypted ? securityPolicy.getEncryptionBlockSize() : 1),
asymmetric ? ((localAsymmetricKeyLength + 7) / 8) : securityPolicy.getSymmetricSignatureSize(),
(int) limits.getSendBufferSize(),
asymmetric,
encryption,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,26 +19,26 @@

package org.apache.plc4x.java.opcua.context;

import static java.util.Map.entry;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.when;

import java.security.Key;
import java.security.KeyPair;
import java.security.KeyStore;
import java.security.PrivateKey;
import java.security.PublicKey;
import java.security.cert.X509Certificate;
import java.util.List;
import java.util.Map.Entry;
import java.util.function.Consumer;
import java.util.function.Supplier;
import org.apache.commons.codec.digest.DigestUtils;
import org.apache.plc4x.java.opcua.TestCertificateGenerator;
import org.apache.plc4x.java.opcua.readwrite.BinaryPayload;
import org.apache.plc4x.java.opcua.readwrite.ChunkType;
import org.apache.plc4x.java.opcua.readwrite.MessagePDU;
import org.apache.plc4x.java.opcua.readwrite.MessageSecurityMode;
import org.apache.plc4x.java.opcua.readwrite.OpcuaMessageRequest;
import org.apache.plc4x.java.opcua.readwrite.OpcuaOpenRequest;
import org.apache.plc4x.java.opcua.readwrite.OpcuaProtocolLimits;
Expand All @@ -51,34 +51,62 @@
import org.apache.plc4x.java.opcua.security.MessageSecurity;
import org.apache.plc4x.java.opcua.security.SecurityPolicy;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.mockito.Mockito;

class EncryptionHandlerTest {

Supplier<Integer> sequenceSupplier = () -> 0;

CertificateKeyPair clientKeyPair;
CertificateKeyPair serverKeyPair;
static List<Arguments> signKeyLengths() {
return List.of(
Arguments.of(SecurityPolicy.Basic128Rsa15, MessageSecurity.SIGN, 1024, 1024),
Arguments.of(SecurityPolicy.Basic128Rsa15, MessageSecurity.SIGN, 2048, 1024),
Arguments.of(SecurityPolicy.Basic128Rsa15, MessageSecurity.SIGN, 1024, 2048),
Arguments.of(SecurityPolicy.Basic128Rsa15, MessageSecurity.SIGN, 2048, 2048)
);
}

static List<Arguments> encryptKeyLengths() {
return List.of(
Arguments.of(SecurityPolicy.Basic128Rsa15, MessageSecurity.SIGN_ENCRYPT, 1024, 1024),
Arguments.of(SecurityPolicy.Basic128Rsa15, MessageSecurity.SIGN_ENCRYPT, 2048, 1024),
Arguments.of(SecurityPolicy.Basic128Rsa15, MessageSecurity.SIGN_ENCRYPT, 1024, 2048),
Arguments.of(SecurityPolicy.Basic128Rsa15, MessageSecurity.SIGN_ENCRYPT, 2048, 2048),
Arguments.of(SecurityPolicy.Basic256Sha256, MessageSecurity.SIGN_ENCRYPT, 1024, 1024),
Arguments.of(SecurityPolicy.Basic256Sha256, MessageSecurity.SIGN_ENCRYPT, 2048, 1024),
Arguments.of(SecurityPolicy.Basic256Sha256, MessageSecurity.SIGN_ENCRYPT, 1024, 2048),
Arguments.of(SecurityPolicy.Basic256Sha256, MessageSecurity.SIGN_ENCRYPT, 2048, 2048)
);
}

@BeforeEach
public void setUp() throws Exception {
Entry<PrivateKey, X509Certificate> clientKeyPair = TestCertificateGenerator.generate(2048, "cn=client", 3600);
Entry<PrivateKey, X509Certificate> serverKeyPair = TestCertificateGenerator.generate(2048, "cn=server", 3600);
private Entry<CertificateKeyPair, CertificateKeyPair> initialize(int client, int server) throws Exception {
Entry<PrivateKey, X509Certificate> clientKeyPair = TestCertificateGenerator.generate(client, "cn=client", 3600);
Entry<PrivateKey, X509Certificate> serverKeyPair = TestCertificateGenerator.generate(server, "cn=server", 3600);

X509Certificate clientCertificate = clientKeyPair.getValue();
PublicKey clientPublicKey = clientCertificate.getPublicKey();
this.clientKeyPair = new CertificateKeyPair(new KeyPair(clientPublicKey, clientKeyPair.getKey()), clientCertificate);

X509Certificate serverCertificate = serverKeyPair.getValue();
PublicKey serverPublicKey = serverCertificate.getPublicKey();
this.serverKeyPair = new CertificateKeyPair(new KeyPair(clientPublicKey, serverKeyPair.getKey()), serverCertificate);

return entry(
new CertificateKeyPair(new KeyPair(clientPublicKey, clientKeyPair.getKey()), clientCertificate),
new CertificateKeyPair(new KeyPair(clientPublicKey, serverKeyPair.getKey()), serverCertificate)
);
}

@Test
void testAsymmetricEncryption() throws Exception {
@ParameterizedTest
@MethodSource("encryptKeyLengths")
void testAsymmetricEncryption(SecurityPolicy securityPolicy, MessageSecurity messageSecurityMode, int client, int server) throws Exception {
Entry<CertificateKeyPair, CertificateKeyPair> keyPairs = initialize(client, server);
CertificateKeyPair clientKeyPair = keyPairs.getKey();
CertificateKeyPair serverKeyPair = keyPairs.getValue();

Conversation conversation = createSecureChannel(clientKeyPair.getCertificate(), serverKeyPair.getCertificate(),
SecurityPolicy.Basic128Rsa15, MessageSecurity.SIGN_ENCRYPT, true, true
securityPolicy, messageSecurityMode, true, true
);

EncryptionHandler handler = new EncryptionHandler(conversation, clientKeyPair.getKeyPair().getPrivate());
Expand All @@ -97,7 +125,7 @@ void testAsymmetricEncryption() throws Exception {
OpcuaOpenRequest request = new OpcuaOpenRequest(ChunkType.FINAL,
new OpenChannelMessageRequest(
(int) securityHeader.getSecureChannelId(),
new PascalString(SecurityPolicy.Basic128Rsa15.getSecurityPolicyUri()),
new PascalString(securityPolicy.getSecurityPolicyUri()),
stringFromBytes(clientKeyPair.getCertificate().getEncoded()),
stringFromBytes(DigestUtils.sha1(serverKeyPair.getCertificate().getEncoded()))
),
Expand All @@ -109,11 +137,11 @@ void testAsymmetricEncryption() throws Exception {
assertEquals(1, pdus.size());

// decrypt
conversation = createSecureChannel(serverKeyPair.getCertificate(), clientKeyPair.getCertificate(), SecurityPolicy.Basic128Rsa15,
MessageSecurity.SIGN_ENCRYPT, true, true);
conversation = createSecureChannel(serverKeyPair.getCertificate(), clientKeyPair.getCertificate(), securityPolicy,
messageSecurityMode, true, true);
EncryptionHandler decrypter = new EncryptionHandler(conversation, serverKeyPair.getPrivateKey());
MessagePDU decoded = decrypter.decodeMessage(pdus.get(0));
assertTrue(decoded instanceof OpcuaOpenRequest);
assertInstanceOf(OpcuaOpenRequest.class, decoded);
OpcuaOpenRequest decodedRequest = (OpcuaOpenRequest) decoded;
SequenceHeader decodedSequenceHeader = decodedRequest.getMessage().getSequenceHeader();
Payload decodedPayload = decodedRequest.getMessage();
Expand All @@ -124,10 +152,15 @@ void testAsymmetricEncryption() throws Exception {

}

@Test
void testAsymmetricEncryptionSign() throws Exception {
@ParameterizedTest
@MethodSource("encryptKeyLengths")
void testAsymmetricEncryptionSign(SecurityPolicy securityPolicy, MessageSecurity messageSecurityMode, int client, int server) throws Exception {
Entry<CertificateKeyPair, CertificateKeyPair> keyPairs = initialize(client, server);
CertificateKeyPair clientKeyPair = keyPairs.getKey();
CertificateKeyPair serverKeyPair = keyPairs.getValue();

Conversation secureChannel = createSecureChannel(clientKeyPair.getCertificate(), serverKeyPair.getCertificate(),
SecurityPolicy.Basic128Rsa15, MessageSecurity.SIGN, true, true);
securityPolicy, messageSecurityMode, true, true);

EncryptionHandler handler = new EncryptionHandler(secureChannel, clientKeyPair.getPrivateKey());

Expand All @@ -145,7 +178,7 @@ void testAsymmetricEncryptionSign() throws Exception {
OpcuaOpenRequest request = new OpcuaOpenRequest(ChunkType.FINAL,
new OpenChannelMessageRequest(
(int) securityHeader.getSecureChannelId(),
new PascalString(SecurityPolicy.Basic128Rsa15.getSecurityPolicyUri()),
new PascalString(securityPolicy.getSecurityPolicyUri()),
stringFromBytes(clientKeyPair.getCertificate().getEncoded()),
stringFromBytes(DigestUtils.sha1(serverKeyPair.getCertificate().getEncoded()))
),
Expand All @@ -157,10 +190,11 @@ void testAsymmetricEncryptionSign() throws Exception {
assertEquals(1, pdus.size());

// decrypt
secureChannel = createSecureChannel(serverKeyPair.getCertificate(), clientKeyPair.getCertificate(), SecurityPolicy.Basic128Rsa15,
MessageSecurity.SIGN, true, true);
secureChannel = createSecureChannel(serverKeyPair.getCertificate(), clientKeyPair.getCertificate(), securityPolicy,
messageSecurityMode, true, true);
EncryptionHandler decryptHandler = new EncryptionHandler(secureChannel, serverKeyPair.getPrivateKey());
MessagePDU decoded = decryptHandler.decodeMessage(pdus.get(0));
assertInstanceOf(OpcuaOpenRequest.class, decoded);
OpcuaOpenRequest decodedRequest = (OpcuaOpenRequest) decoded;
SequenceHeader decodedSequenceHeader = decodedRequest.getMessage().getSequenceHeader();
Payload decodedPayload = decodedRequest.getMessage();
Expand All @@ -171,10 +205,15 @@ void testAsymmetricEncryptionSign() throws Exception {

}

@Test
void testSymmetricEncryption() throws Exception {
Conversation secureChannel = createSecureChannel(clientKeyPair.getCertificate(), serverKeyPair.getCertificate(), SecurityPolicy.Basic128Rsa15,
MessageSecurity.SIGN_ENCRYPT, true, true);
@ParameterizedTest
@MethodSource("signKeyLengths")
void testSymmetricEncryption(SecurityPolicy securityPolicy, MessageSecurity messageSecurityMode, int client, int server) throws Exception {
Entry<CertificateKeyPair, CertificateKeyPair> keyPairs = initialize(client, server);
CertificateKeyPair clientKeyPair = keyPairs.getKey();
CertificateKeyPair serverKeyPair = keyPairs.getValue();

Conversation secureChannel = createSecureChannel(clientKeyPair.getCertificate(), serverKeyPair.getCertificate(), securityPolicy,
messageSecurityMode, true, true);

EncryptionHandler handler = new EncryptionHandler(secureChannel, clientKeyPair.getPrivateKey());

Expand All @@ -199,8 +238,8 @@ void testSymmetricEncryption() throws Exception {
assertEquals(1, pdus.size());

// decrypt
secureChannel = createSecureChannel(serverKeyPair.getCertificate(), clientKeyPair.getCertificate(), SecurityPolicy.Basic128Rsa15,
MessageSecurity.SIGN, true, true);
secureChannel = createSecureChannel(serverKeyPair.getCertificate(), clientKeyPair.getCertificate(), securityPolicy,
messageSecurityMode, true, true);
EncryptionHandler decryptHandler = new EncryptionHandler(secureChannel, serverKeyPair.getPrivateKey());
MessagePDU decoded = decryptHandler.decodeMessage(pdus.get(0));
OpcuaMessageRequest decodedRequest = (OpcuaMessageRequest) decoded;
Expand All @@ -212,10 +251,15 @@ void testSymmetricEncryption() throws Exception {
}
}

@Test
void testSymmetricEncryptionSign() throws Exception {
Conversation secureChannel = createSecureChannel(clientKeyPair.getCertificate(), serverKeyPair.getCertificate(), SecurityPolicy.Basic128Rsa15,
MessageSecurity.SIGN, true, true);
@ParameterizedTest
@MethodSource("signKeyLengths")
void testSymmetricEncryptionSign(SecurityPolicy securityPolicy, MessageSecurity messageSecurityMode, int client, int server) throws Exception {
Entry<CertificateKeyPair, CertificateKeyPair> keyPairs = initialize(client, server);
CertificateKeyPair clientKeyPair = keyPairs.getKey();
CertificateKeyPair serverKeyPair = keyPairs.getValue();

Conversation secureChannel = createSecureChannel(clientKeyPair.getCertificate(), serverKeyPair.getCertificate(), securityPolicy,
messageSecurityMode, true, true);

EncryptionHandler handler = new EncryptionHandler(secureChannel, clientKeyPair.getPrivateKey());

Expand All @@ -240,8 +284,8 @@ void testSymmetricEncryptionSign() throws Exception {
assertEquals(1, pdus.size());

// decrypt
secureChannel = createSecureChannel(serverKeyPair.getCertificate(), clientKeyPair.getCertificate(), SecurityPolicy.Basic128Rsa15,
MessageSecurity.SIGN, true, true);
secureChannel = createSecureChannel(serverKeyPair.getCertificate(), clientKeyPair.getCertificate(), securityPolicy,
messageSecurityMode, true, true);
EncryptionHandler decryptHandler = new EncryptionHandler(secureChannel, serverKeyPair.getPrivateKey());
MessagePDU decoded = decryptHandler.decodeMessage(pdus.get(0));
OpcuaMessageRequest decodedRequest = (OpcuaMessageRequest) decoded;
Expand Down

0 comments on commit cb8153a

Please sign in to comment.