Skip to content

Commit

Permalink
GH-524: More ChaCha20-Poly1305 optimizations
Browse files Browse the repository at this point in the history
In the ChaChaEngine, exploit peculiarities of its use in SSH: the
counter is effectively a 32-bit value that never overflows, and the
nonce is the SSH packet sequence number, which is also 32 bits wide and
wraps on overflow. As a result, two of the ints of the engine state are
always zero, and the handling of nonce and counter can be slightly
simplified.

In Poly1305Mac, inline the long multiplications. Avoid extensions
from int to long for the precomputed values (r, s, k); store them as
longs up front. For h, reduce the number of extensions from 25 to 5
by doing it once before the multiplications.

As a side effect, this part of the code is also nicer to read.
  • Loading branch information
tomaswolf committed Jul 29, 2024
1 parent 4b6f809 commit d53fea9
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 60 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@

import org.apache.sshd.common.mac.Mac;
import org.apache.sshd.common.mac.Poly1305Mac;
import org.apache.sshd.common.util.NumberUtils;
import org.apache.sshd.common.util.ValidateUtils;
import org.apache.sshd.common.util.buffer.BufferUtils;

Expand All @@ -38,11 +37,11 @@
public class ChaCha20Cipher implements Cipher {
protected final ChaChaEngine headerEngine = new ChaChaEngine();
protected final ChaChaEngine bodyEngine = new ChaChaEngine();
protected final Mac mac = new Poly1305Mac();
protected final Mac mac;
protected Mode mode;

// Creates the cipher and assigns a new Poly1305Mac instance to the 'mac' field.
// NOTE(review): this hunk appears to show both the removed line ("// empty") and the added line (the
// assignment) of the commit; the +/- diff markers are missing in this view - confirm against the repository.
public ChaCha20Cipher() {
// empty
this.mac = new Poly1305Mac();
}

@Override
Expand Down Expand Up @@ -131,7 +130,7 @@ public int getKdfSize() {

// Returns the key size of this cipher as a fixed constant.
// NOTE(review): two return statements are visible here (256, then 512). This looks like the removed and
// the added line of a diff hunk shown without +/- markers; as plain Java the second return would be
// unreachable. Which value is current cannot be told from this view - confirm against the repository.
// The unit of the returned value (presumably bits) is not stated here either - TODO confirm.
@Override
public int getKeySize() {
return 256;
return 512;
}

protected static class ChaChaEngine {
Expand All @@ -142,15 +141,24 @@ protected static class ChaChaEngine {
private static final int KEY_INTS = KEY_BYTES / Integer.BYTES;
private static final int COUNTER_OFFSET = 12;
private static final int NONCE_OFFSET = 14;
private static final int NONCE_BYTES = 8;
private static final int NONCE_INTS = NONCE_BYTES / Integer.BYTES;
private static final int[] ENGINE_STATE_HEADER = unpackSigmaString(
"expand 32-byte k".getBytes(StandardCharsets.US_ASCII));

protected final int[] engineState = new int[BLOCK_INTS];
protected final byte[] keyStream = new byte[BLOCK_BYTES];
protected final byte[] nonce = new byte[NONCE_BYTES];
protected final byte[] nonce = new byte[Integer.BYTES];
protected long initialNonce;
protected long nonceVal;

// Elements 12 to 15 in the engineState are the counter and the nonce. The counter is a 64-bit little-
// endian value; the nonce is a 64-bit big-endian value.
//
// The counter always starts at zero, is incremented with each full block (64 bytes), and in SSH never
// overflows 32 bits because it counts only inside a single SSH packet. The nonce in SSH is the packet
// sequence number, which is a 32-bit unsigned int that wraps around on overflow.
//
// Therefore, engineState[13] and engineState[14] are always zero. engineState[12] is the counter, and
// engineState[15] is the packet sequence number in inverse byte order.

protected ChaChaEngine() {
System.arraycopy(ENGINE_STATE_HEADER, 0, engineState, 0, 4);
Expand All @@ -161,21 +169,25 @@ protected void initKey(byte[] key) {
}

// Initializes the nonce-related fields and engine state words from the given byte array.
// NOTE(review): this hunk appears to contain both the removed and the added lines of the commit (the +/-
// diff markers are missing in this view), so the statements below read as two alternative implementations
// rather than one sequence - confirm against the repository.
protected void initNonce(byte[] nonce) {
// Variant A: read the array as a single long via BufferUtils.getLong, fill the engine state from
// NONCE_OFFSET onwards via unpackIntsLE, and copy the first NONCE_BYTES bytes into this.nonce.
initialNonce = BufferUtils.getLong(nonce, 0, NumberUtils.length(nonce));
unpackIntsLE(nonce, 0, NONCE_INTS, engineState, NONCE_OFFSET);
System.arraycopy(nonce, 0, this.nonce, 0, NONCE_BYTES);
// Variant B: the value read from the first Integer.BYTES bytes must be zero; otherwise checkState is
// handed the "not a valid SSH packet sequence number" message.
long hiBits = BufferUtils.getUInt(nonce, 0, Integer.BYTES);
ValidateUtils.checkState(hiBits == 0, "ChaCha20 nonce is not a valid SSH packet sequence number");
// The value read from the next Integer.BYTES bytes becomes both the initial and the current nonce value.
initialNonce = BufferUtils.getUInt(nonce, Integer.BYTES, Integer.BYTES);
nonceVal = initialNonce;
// engineState[NONCE_OFFSET] is cleared; engineState[NONCE_OFFSET + 1] receives the same four bytes as
// unpacked by Poly1305Mac.unpackIntLE (presumably little-endian, judging by the name - TODO confirm).
engineState[NONCE_OFFSET] = 0;
engineState[NONCE_OFFSET + 1] = Poly1305Mac.unpackIntLE(nonce, Integer.BYTES);
}

// Advances the nonce to its next value and refreshes the corresponding engine state.
// NOTE(review): this hunk appears to contain both the removed and the added lines of the commit (the +/-
// diff markers are missing in this view), so the statements below read as two alternative implementations
// rather than one sequence - confirm against the repository.
protected void advanceNonce() {
// Variant A: re-read the stored nonce bytes as a long, add one, reject a value equal to initialNonce,
// write the result back into the nonce bytes and reload the engine state via unpackIntsLE.
long counter = BufferUtils.getLong(nonce, 0, NONCE_BYTES) + 1;
ValidateUtils.checkState(counter != initialNonce, "Packet sequence number cannot be reused with the same key");
BufferUtils.putLong(counter, nonce, 0, NONCE_BYTES);
unpackIntsLE(nonce, 0, NONCE_INTS, engineState, NONCE_OFFSET);
// Variant B: keep the nonce as a long field and increment it modulo 2^32.
// SSH packet sequence number wraps around on uint32 overflow.
nonceVal = (nonceVal + 1) & 0xFFFF_FFFFL;
// Arriving back at the initial value means the 32-bit range has been cycled through completely.
ValidateUtils.checkState(nonceVal != initialNonce, "Packet sequence number cannot be reused with the same key");
// Write the new value into the nonce bytes, then load those bytes into engineState[NONCE_OFFSET + 1].
BufferUtils.putUInt(nonceVal, nonce, 0, Integer.BYTES);
engineState[NONCE_OFFSET + 1] = Poly1305Mac.unpackIntLE(nonce, 0);
}

// Stores the given counter value in the engine state words at COUNTER_OFFSET and COUNTER_OFFSET + 1.
// NOTE(review): engineState[COUNTER_OFFSET + 1] is assigned twice below. This looks like the removed line
// (high half of the counter) followed by the added line (constant zero) of a diff hunk shown without +/-
// markers - confirm against the repository.
protected void initCounter(long counter) {
// Low 32 bits of the counter.
engineState[COUNTER_OFFSET] = (int) counter;
// High 32 bits of the counter.
engineState[COUNTER_OFFSET + 1] = (int) (counter >>> Integer.SIZE);
engineState[COUNTER_OFFSET + 1] = 0; // Always zero; and counter never overflows in SSH.
}

// one-shot usage
Expand All @@ -187,11 +199,7 @@ protected void crypt(byte[] in, int offset, int length, byte[] out, int outOffse
out[outOffset++] = (byte) (in[offset++] ^ keyStream[i]);
}
length -= want;
int lo = ++engineState[COUNTER_OFFSET];
if (lo == 0) {
// overflow
++engineState[COUNTER_OFFSET + 1];
}
++engineState[COUNTER_OFFSET]; // Never overflows in SSH
}
}

Expand All @@ -216,10 +224,10 @@ protected void setKeyStream(int[] engine) {
int x9 = engine[9];
int x10 = engine[10];
int x11 = engine[11];
int x12 = engine[12];
int x13 = engine[13];
int x14 = engine[14];
int x15 = engine[15];
int x12 = engine[12]; // counter
int x13 = engine[13]; // 0
int x14 = engine[14]; // 0
int x15 = engine[15]; // nonce

for (int i = 0; i < 10; i++) {
// Columns
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,19 +36,19 @@ public class Poly1305Mac implements Mac {
public static final int KEY_BYTES = 32;
private static final int BLOCK_SIZE = 16;

private int r0;
private int r1;
private int r2;
private int r3;
private int r4;
private int s1;
private int s2;
private int s3;
private int s4;
private int k0;
private int k1;
private int k2;
private int k3;
private long r0;
private long r1;
private long r2;
private long r3;
private long r4;
private long s1;
private long s2;
private long s3;
private long s4;
private long k0;
private long k1;
private long k2;
private long k3;

private int h0;
private int h1;
Expand Down Expand Up @@ -91,10 +91,10 @@ public void init(byte[] key) throws Exception {
s3 = r3 * 5;
s4 = r4 * 5;

k0 = unpackIntLE(key, 16);
k1 = unpackIntLE(key, 20);
k2 = unpackIntLE(key, 24);
k3 = unpackIntLE(key, 28);
k0 = unpackIntLE(key, 16) & 0xFFFF_FFFFL;
k1 = unpackIntLE(key, 20) & 0xFFFF_FFFFL;
k2 = unpackIntLE(key, 24) & 0xFFFF_FFFFL;
k3 = unpackIntLE(key, 28) & 0xFFFF_FFFFL;

currentBlockOffset = 0;
}
Expand Down Expand Up @@ -186,10 +186,10 @@ public void doFinal(byte[] out, int offset) throws Exception {
h3 = h3 & nb | g3 & b;
h4 = h4 & nb | g4 & b;

long f0 = ((h0 | h1 << 26) & 0xFFFF_FFFFL) + (k0 & 0xFFFF_FFFFL);
long f1 = ((h1 >>> 6 | h2 << 20) & 0xFFFF_FFFFL) + (k1 & 0xFFFF_FFFFL);
long f2 = ((h2 >>> 12 | h3 << 14) & 0xFFFF_FFFFL) + (k2 & 0xFFFF_FFFFL);
long f3 = ((h3 >>> 18 | h4 << 8) & 0xFFFF_FFFFL) + (k3 & 0xFFFF_FFFFL);
long f0 = ((h0 | h1 << 26) & 0xFFFF_FFFFL) + k0;
long f1 = ((h1 >>> 6 | h2 << 20) & 0xFFFF_FFFFL) + k1;
long f2 = ((h2 >>> 12 | h3 << 14) & 0xFFFF_FFFFL) + k2;
long f3 = ((h3 >>> 18 | h4 << 8) & 0xFFFF_FFFFL) + k3;

packIntLE((int) f0, out, offset);
f1 += f0 >>> 32;
Expand Down Expand Up @@ -219,16 +219,18 @@ private void processBlock(byte[] block, int offset, int length) {
h4 += 1 << 24;
}

long tp0 = unsignedProduct(h0, r0) + unsignedProduct(h1, s4) + unsignedProduct(h2, s3) + unsignedProduct(h3, s2)
+ unsignedProduct(h4, s1);
long tp1 = unsignedProduct(h0, r1) + unsignedProduct(h1, r0) + unsignedProduct(h2, s4) + unsignedProduct(h3, s3)
+ unsignedProduct(h4, s2);
long tp2 = unsignedProduct(h0, r2) + unsignedProduct(h1, r1) + unsignedProduct(h2, r0) + unsignedProduct(h3, s4)
+ unsignedProduct(h4, s3);
long tp3 = unsignedProduct(h0, r3) + unsignedProduct(h1, r2) + unsignedProduct(h2, r1) + unsignedProduct(h3, r0)
+ unsignedProduct(h4, s4);
long tp4 = unsignedProduct(h0, r4) + unsignedProduct(h1, r3) + unsignedProduct(h2, r2) + unsignedProduct(h3, r1)
+ unsignedProduct(h4, r0);
// The high bits of h0 to h4 are guaranteed to be zero, so we can just let the compiler extend the ints.
// No need to do a & 0xFFFF_FFFFL.
long l0 = h0;
long l1 = h1;
long l2 = h2;
long l3 = h3;
long l4 = h4;
long tp0 = l0 * r0 + l1 * s4 + l2 * s3 + l3 * s2 + l4 * s1;
long tp1 = l0 * r1 + l1 * r0 + l2 * s4 + l3 * s3 + l4 * s2;
long tp2 = l0 * r2 + l1 * r1 + l2 * r0 + l3 * s4 + l4 * s3;
long tp3 = l0 * r3 + l1 * r2 + l2 * r1 + l3 * r0 + l4 * s4;
long tp4 = l0 * r4 + l1 * r3 + l2 * r2 + l3 * r1 + l4 * r0;

h0 = (int) tp0 & 0x3ffffff;
tp1 += tp0 >>> 26;
Expand Down Expand Up @@ -278,8 +280,4 @@ public static void packIntLE(int value, byte[] dst, int off) {
dst[off++] = (byte) (value >> 16);
dst[off] = (byte) (value >> 24);
}

// Returns the 64-bit product of i1, taken as an unsigned 32-bit value, and i2.
// Masking with 0xFFFF_FFFFL zero-extends i1 to long; i2 is widened by the ordinary int-to-long conversion,
// which sign-extends it.
// NOTE(review): the commit message says the long multiplications were inlined, so this helper is
// presumably one of the lines removed by this commit - confirm against the repository.
private static long unsignedProduct(int i1, int i2) {
return (i1 & 0xFFFF_FFFFL) * i2;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,12 @@ private static void testCipherEncryption(Random rnd, Cipher cipher) throws Excep
byte[] key = new byte[cipher.getKdfSize()];
rnd.nextBytes(key);
byte[] iv = new byte[cipher.getIVSize()];
rnd.nextBytes(iv);
// ChaCha20 has an SSH packet sequence number as IV! Do not use random IVs with ChaCha20!
if (cipher.getAlgorithm().startsWith("ChaCha20")) {
iv[iv.length - 1] = 42;
} else {
rnd.nextBytes(iv);
}
cipher.init(Cipher.Mode.Encrypt, key, iv);

byte[] data = new byte[cipher.getCipherBlockSize() + cipher.getAuthenticationTagSize()];
Expand Down

0 comments on commit d53fea9

Please sign in to comment.