Skip to content

Commit

Permalink
src: move CipherCtx methods to ncrypto
Browse files Browse the repository at this point in the history
  • Loading branch information
jasnell committed Jan 7, 2025
1 parent 3dacb82 commit f22a969
Show file tree
Hide file tree
Showing 6 changed files with 301 additions and 163 deletions.
151 changes: 130 additions & 21 deletions deps/ncrypto/ncrypto.cc
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,8 @@ int BignumPointer::isPrime(int nchecks,
return -1;
BN_GENCB_set(
cb.get(),
// TODO(@jasnell): This could be refactored to allow inlining.
// Not too important right now tho.
[](int a, int b, BN_GENCB* ctx) mutable -> int {
PrimeCheckCallback& ptr =
*static_cast<PrimeCheckCallback*>(BN_GENCB_get_arg(ctx));
Expand Down Expand Up @@ -374,6 +376,7 @@ bool BignumPointer::generate(const PrimeConfig& params,
BignumPointer BignumPointer::NewSub(const BignumPointer& a,
const BignumPointer& b) {
BignumPointer res = New();
if (!res) return {};
if (!BN_sub(res.get(), a.get(), b.get())) {
return {};
}
Expand All @@ -382,6 +385,7 @@ BignumPointer BignumPointer::NewSub(const BignumPointer& a,

BignumPointer BignumPointer::NewLShift(size_t length) {
BignumPointer res = New();
if (!res) return {};
if (!BN_lshift(res.get(), One(), length)) {
return {};
}
Expand Down Expand Up @@ -1192,8 +1196,8 @@ std::string_view X509Pointer::ErrorCode(int32_t err) { // NOLINT(runtime/int)
return "UNSPECIFIED";
}

std::string_view X509Pointer::ErrorReason(int32_t err) {
if (err == X509_V_OK) return "";
std::optional<std::string_view> X509Pointer::ErrorReason(int32_t err) {
if (err == X509_V_OK) return std::nullopt;
return X509_verify_cert_error_string(err);
}

Expand Down Expand Up @@ -2235,11 +2239,12 @@ void SSLPointer::getCiphers(
// document them, but since there are only 5, easier to just add them manually
// and not have to explain their absence in the API docs. They are lower-cased
// because the docs say they will be.
static const char* TLS13_CIPHERS[] = {"tls_aes_256_gcm_sha384",
"tls_chacha20_poly1305_sha256",
"tls_aes_128_gcm_sha256",
"tls_aes_128_ccm_8_sha256",
"tls_aes_128_ccm_sha256"};
static constexpr const char* TLS13_CIPHERS[] = {
"tls_aes_256_gcm_sha384",
"tls_chacha20_poly1305_sha256",
"tls_aes_128_gcm_sha256",
"tls_aes_128_ccm_8_sha256",
"tls_aes_128_ccm_sha256"};

const int n = sk_SSL_CIPHER_num(ciphers);

Expand All @@ -2249,8 +2254,7 @@ void SSLPointer::getCiphers(
}

for (unsigned i = 0; i < 5; ++i) {
const char* name = TLS13_CIPHERS[i];
cb(name);
cb(TLS13_CIPHERS[i]);
}
}

Expand All @@ -2265,7 +2269,6 @@ bool SSLPointer::setSniContext(const SSLCtxPointer& ctx) const {
if (!x509) return false;
EVP_PKEY* pkey = SSL_CTX_get0_privatekey(ctx.get());
STACK_OF(X509) * chain;

int err = SSL_CTX_get0_chain_certs(ctx.get(), &chain);
if (err == 1) err = SSL_use_certificate(get(), x509);
if (err == 1) err = SSL_use_PrivateKey(get(), pkey);
Expand Down Expand Up @@ -2294,7 +2297,7 @@ std::optional<uint32_t> SSLPointer::verifyPeerCertificate() const {
}

const std::string_view SSLPointer::getClientHelloAlpn() const {
if (ssl_ == nullptr) return std::string_view();
if (ssl_ == nullptr) return {};
const unsigned char* buf;
size_t len;
size_t rem;
Expand All @@ -2305,34 +2308,34 @@ const std::string_view SSLPointer::getClientHelloAlpn() const {
&buf,
&rem) ||
rem < 2) {
return nullptr;
return {};
}

len = (buf[0] << 8) | buf[1];
if (len + 2 != rem) return nullptr;
if (len + 2 != rem) return {};
return reinterpret_cast<const char*>(buf + 3);
}

const std::string_view SSLPointer::getClientHelloServerName() const {
if (ssl_ == nullptr) return std::string_view();
if (ssl_ == nullptr) return {};
const unsigned char* buf;
size_t len;
size_t rem;

if (!SSL_client_hello_get0_ext(get(), TLSEXT_TYPE_server_name, &buf, &rem) ||
rem <= 2) {
return nullptr;
return {};
}

len = (*buf << 8) | *(buf + 1);
if (len + 2 != rem) return nullptr;
if (len + 2 != rem) return {};
rem = len;

if (rem == 0 || *(buf + 2) != TLSEXT_NAMETYPE_host_name) return nullptr;
if (rem == 0 || *(buf + 2) != TLSEXT_NAMETYPE_host_name) return {};
rem--;
if (rem <= 2) return nullptr;
if (rem <= 2) return {};
len = (*(buf + 3) << 8) | *(buf + 4);
if (len + 2 > rem) return nullptr;
if (len + 2 > rem) return {};
return reinterpret_cast<const char*>(buf + 5);
}

Expand Down Expand Up @@ -2453,7 +2456,7 @@ int Cipher::getNid() const {
return EVP_CIPHER_nid(cipher_);
}

const std::string_view Cipher::getModeLabel() const {
std::string_view Cipher::getModeLabel() const {
if (!cipher_) return {};
switch (getMode()) {
case EVP_CIPH_CCM_MODE:
Expand Down Expand Up @@ -2482,7 +2485,7 @@ const std::string_view Cipher::getModeLabel() const {
return "{unknown}";
}

const std::string_view Cipher::getName() const {
std::string_view Cipher::getName() const {
if (!cipher_) return {};
// OBJ_nid2sn(EVP_CIPHER_nid(cipher)) is used here instead of
// EVP_CIPHER_name(cipher) for compatibility with BoringSSL.
Expand All @@ -2504,4 +2507,110 @@ bool Cipher::isSupportedAuthenticatedMode() const {
}
}

// ============================================================================

CipherCtxPointer CipherCtxPointer::New() {
auto ret = CipherCtxPointer(EVP_CIPHER_CTX_new());
if (!ret) return {};
EVP_CIPHER_CTX_init(ret.get());
return ret;
}

CipherCtxPointer::CipherCtxPointer(EVP_CIPHER_CTX* ctx) : ctx_(ctx) {}

CipherCtxPointer::CipherCtxPointer(CipherCtxPointer&& other) noexcept
: ctx_(other.release()) {}

CipherCtxPointer& CipherCtxPointer::operator=(
CipherCtxPointer&& other) noexcept {
if (this == &other) return *this;
this->~CipherCtxPointer();
return *new (this) CipherCtxPointer(std::move(other));
}

CipherCtxPointer::~CipherCtxPointer() {
reset();
}

void CipherCtxPointer::reset(EVP_CIPHER_CTX* ctx) {
ctx_.reset(ctx);
}

EVP_CIPHER_CTX* CipherCtxPointer::release() {
return ctx_.release();
}

void CipherCtxPointer::setFlags(int flags) {
if (!ctx_) return;
EVP_CIPHER_CTX_set_flags(ctx_.get(), flags);
}

bool CipherCtxPointer::setKeyLength(size_t length) {
if (!ctx_) return false;
return EVP_CIPHER_CTX_set_key_length(ctx_.get(), length);
}

bool CipherCtxPointer::setIvLength(size_t length) {
if (!ctx_) return false;
return EVP_CIPHER_CTX_ctrl(
ctx_.get(), EVP_CTRL_AEAD_SET_IVLEN, length, nullptr);
}

bool CipherCtxPointer::setAeadTag(const Buffer<const char>& tag) {
if (!ctx_) return false;
return EVP_CIPHER_CTX_ctrl(
ctx_.get(), EVP_CTRL_AEAD_SET_TAG, tag.len, const_cast<char*>(tag.data));
}

bool CipherCtxPointer::setAeadTagLength(size_t length) {
if (!ctx_) return false;
return EVP_CIPHER_CTX_ctrl(
ctx_.get(), EVP_CTRL_AEAD_SET_TAG, length, nullptr);
}

bool CipherCtxPointer::setPadding(bool padding) {
if (!ctx_) return false;
return EVP_CIPHER_CTX_set_padding(ctx_.get(), padding);
}

int CipherCtxPointer::getBlockSize() const {
if (!ctx_) return 0;
return EVP_CIPHER_CTX_block_size(ctx_.get());
}

int CipherCtxPointer::getMode() const {
if (!ctx_) return 0;
return EVP_CIPHER_CTX_mode(ctx_.get());
}

int CipherCtxPointer::getNid() const {
if (!ctx_) return 0;
return EVP_CIPHER_CTX_nid(ctx_.get());
}

bool CipherCtxPointer::init(const Cipher& cipher,
bool encrypt,
const unsigned char* key,
const unsigned char* iv) {
if (!ctx_) return false;
return EVP_CipherInit_ex(
ctx_.get(), cipher, nullptr, key, iv, encrypt ? 1 : 0) == 1;
}

bool CipherCtxPointer::update(const Buffer<const unsigned char>& in,
unsigned char* out,
int* out_len,
bool finalize) {
if (!ctx_) return false;
if (!finalize) {
return EVP_CipherUpdate(ctx_.get(), out, out_len, in.data, in.len) == 1;
}
return EVP_CipherFinal_ex(ctx_.get(), out, out_len) == 1;
}

bool CipherCtxPointer::getAeadTag(size_t len, unsigned char* out) {
if (!ctx_) return false;
return EVP_CIPHER_CTX_ctrl(ctx_.get(), EVP_CTRL_AEAD_GET_TAG, len, out);
}

} // namespace ncrypto
54 changes: 50 additions & 4 deletions deps/ncrypto/ncrypto.h
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,6 @@ using DeleteFnPtr = typename FunctionDeleter<T, function>::Pointer;

using BignumCtxPointer = DeleteFnPtr<BN_CTX, BN_CTX_free>;
using BignumGenCallbackPointer = DeleteFnPtr<BN_GENCB, BN_GENCB_free>;
using CipherCtxPointer = DeleteFnPtr<EVP_CIPHER_CTX, EVP_CIPHER_CTX_free>;
using DSAPointer = DeleteFnPtr<DSA, DSA_free>;
using DSASigPointer = DeleteFnPtr<DSA_SIG, DSA_SIG_free>;
using ECDSASigPointer = DeleteFnPtr<ECDSA_SIG, ECDSA_SIG_free>;
Expand All @@ -213,6 +212,8 @@ using PKCS8Pointer = DeleteFnPtr<PKCS8_PRIV_KEY_INFO, PKCS8_PRIV_KEY_INFO_free>;
using RSAPointer = DeleteFnPtr<RSA, RSA_free>;
using SSLSessionPointer = DeleteFnPtr<SSL_SESSION, SSL_SESSION_free>;

class CipherCtxPointer;

struct StackOfXASN1Deleter {
void operator()(STACK_OF(ASN1_OBJECT) * p) const {
sk_ASN1_OBJECT_pop_free(p, ASN1_OBJECT_free);
Expand Down Expand Up @@ -248,8 +249,8 @@ class Cipher final {
int getIvLength() const;
int getKeyLength() const;
int getBlockSize() const;
const std::string_view getModeLabel() const;
const std::string_view getName() const;
std::string_view getModeLabel() const;
std::string_view getName() const;

bool isSupportedAuthenticatedMode() const;

Expand Down Expand Up @@ -425,6 +426,51 @@ class BignumPointer final {
static bool defaultPrimeCheckCallback(int, int) { return 1; }
};

class CipherCtxPointer final {
public:
static CipherCtxPointer New();

CipherCtxPointer() = default;
explicit CipherCtxPointer(EVP_CIPHER_CTX* ctx);
CipherCtxPointer(CipherCtxPointer&& other) noexcept;
CipherCtxPointer& operator=(CipherCtxPointer&& other) noexcept;
NCRYPTO_DISALLOW_COPY(CipherCtxPointer)
~CipherCtxPointer();

inline bool operator==(std::nullptr_t) const noexcept {
return ctx_ == nullptr;
}
inline operator bool() const { return ctx_ != nullptr; }
inline EVP_CIPHER_CTX* get() const { return ctx_.get(); }
inline operator EVP_CIPHER_CTX*() const { return ctx_.get(); }
void reset(EVP_CIPHER_CTX* ctx = nullptr);
EVP_CIPHER_CTX* release();

void setFlags(int flags);
bool setKeyLength(size_t length);
bool setIvLength(size_t length);
bool setAeadTag(const Buffer<const char>& tag);
bool setAeadTagLength(size_t length);
bool setPadding(bool padding);
bool init(const Cipher& cipher,
bool encrypt,
const unsigned char* key = nullptr,
const unsigned char* iv = nullptr);

int getBlockSize() const;
int getMode() const;
int getNid() const;

bool update(const Buffer<const unsigned char>& in,
unsigned char* out,
int* out_len,
bool finalize = false);
bool getAeadTag(size_t len, unsigned char* out);

private:
DeleteFnPtr<EVP_CIPHER_CTX, EVP_CIPHER_CTX_free> ctx_;
};

class EVPKeyPointer final {
public:
static EVPKeyPointer New();
Expand Down Expand Up @@ -772,7 +818,7 @@ class X509Pointer final {
operator X509View() const { return view(); }

static std::string_view ErrorCode(int32_t err);
static std::string_view ErrorReason(int32_t err);
static std::optional<std::string_view> ErrorReason(int32_t err);

private:
DeleteFnPtr<X509, X509_free> cert_;
Expand Down
Loading

0 comments on commit f22a969

Please sign in to comment.