Skip to content

Commit

Permalink
optimized AES decryption's memory allocations
Browse files Browse the repository at this point in the history
  • Loading branch information
mrdcvlsc committed Jul 16, 2023
1 parent 828e313 commit 81b9545
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 62 deletions.
115 changes: 54 additions & 61 deletions main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
#define CRYPTOLIB "portable"
#endif

#define BETHELA_VERSION "version 3.6.1"
#define BETHELA_VERSION "version 3.7.0"

#define HELP_FLAG "--help"
#define VERSION_FLAG "--version"
Expand Down Expand Up @@ -426,10 +426,7 @@ int main(int argc, char *args[]) {

bool run_thread = true;

char *next_buffer = new char[BUFFER_BYTESIZE];
char *prev_buffer = new char[BUFFER_BYTESIZE];
Krypt::Bytes *decryptedBuffer = new Krypt::Bytes[BUFFER_BYTESIZE];

char *read_buffer = new char[BUFFER_BYTESIZE];
char *filesig = new char[bconst::FILESIGNATURE.size()];

while (run_thread) {
Expand Down Expand Up @@ -461,6 +458,7 @@ int main(int argc, char *args[]) {
"it might be read protected, corrupted or non-existent...\n";
output_mtx.unlock();
} else {

output_mtx.lock();
std::cout << "decrypting : " << target_file << "...\n";
output_mtx.unlock();
Expand All @@ -478,90 +476,85 @@ int main(int argc, char *args[]) {
output_file.close();
output_file.open(outfname, std::ios::binary | std::ios::app);

unsigned char *iv = new unsigned char[AES_BLOCKSIZE];
unsigned char iv[AES_BLOCKSIZE];
curr_file.read(reinterpret_cast<char *>(iv), AES_BLOCKSIZE);

char *swap_buffer_ptr;
// new decryption

curr_file.read(prev_buffer, BUFFER_BYTESIZE);
size_t prev_buffer_size = curr_file.gcount();
size_t next_buffer_size = 0;
constexpr size_t BUFFER_SIZE_NOLAST = BUFFER_BYTESIZE - AES_BLOCKSIZE;
char decrypted_block_holder[AES_BLOCKSIZE] = {};
char last_block[AES_BLOCKSIZE] = {};

while (!curr_file.eof()) {
curr_file.read(next_buffer, BUFFER_BYTESIZE);
next_buffer_size = curr_file.gcount();
bool lastblock_remains = false;

if (!next_buffer_size) {
break;
}
while (!curr_file.eof()) {
curr_file.read(read_buffer, BUFFER_BYTESIZE);
size_t read_buffer_size = curr_file.gcount();

for (size_t index = 0; index < prev_buffer_size; index += AES_BLOCKSIZE) {
aes_scheme.blockDecrypt(
reinterpret_cast<unsigned char *>(prev_buffer + index),
reinterpret_cast<unsigned char *>(decryptedBuffer + index), iv
);
if (lastblock_remains) {
if (read_buffer_size) {
aes_scheme.blockDecrypt(
reinterpret_cast<unsigned char *>(last_block),
reinterpret_cast<unsigned char *>(decrypted_block_holder), iv
);
output_file.write(reinterpret_cast<char *>(decrypted_block_holder), AES_BLOCKSIZE);
lastblock_remains = false;
} else {
break;
}
}

output_file.write(reinterpret_cast<char *>(decryptedBuffer), prev_buffer_size);

swap_buffer_ptr = next_buffer;
next_buffer = prev_buffer;
prev_buffer = swap_buffer_ptr;
swap_buffer_ptr = nullptr;
std::swap(next_buffer_size, prev_buffer_size);
}

size_t remaining_blocks = prev_buffer_size / AES_BLOCKSIZE;
size_t remaining_bytes = prev_buffer_size % AES_BLOCKSIZE;
size_t index = 0;
if (read_buffer_size == BUFFER_BYTESIZE) {
size_t index;
for (index = 0; index < BUFFER_SIZE_NOLAST; index += AES_BLOCKSIZE) {
aes_scheme.blockDecrypt(
reinterpret_cast<unsigned char *>(read_buffer + index),
reinterpret_cast<unsigned char *>(decrypted_block_holder), iv
);
std::memcpy(read_buffer + index, decrypted_block_holder, AES_BLOCKSIZE);
}

bool excludeLastBlock = (remaining_blocks && remaining_bytes == 0);
output_file.write(reinterpret_cast<char *>(read_buffer), BUFFER_SIZE_NOLAST);
std::memcpy(last_block, read_buffer + index, AES_BLOCKSIZE);
} else if (read_buffer_size % AES_BLOCKSIZE == 0) {
size_t index;
for (index = 0; index < read_buffer_size - AES_BLOCKSIZE; index += AES_BLOCKSIZE) {
aes_scheme.blockDecrypt(
reinterpret_cast<unsigned char *>(read_buffer + index),
reinterpret_cast<unsigned char *>(decrypted_block_holder), iv
);
std::memcpy(read_buffer + index, decrypted_block_holder, AES_BLOCKSIZE);
}

if (remaining_blocks) {
for (index = 0; index < remaining_blocks - excludeLastBlock; ++index) {
aes_scheme.blockDecrypt(
reinterpret_cast<unsigned char *>(prev_buffer + (index * AES_BLOCKSIZE)),
reinterpret_cast<unsigned char *>(decryptedBuffer + (index * AES_BLOCKSIZE)), iv
);
output_file.write(reinterpret_cast<char *>(read_buffer), read_buffer_size - AES_BLOCKSIZE);
std::memcpy(last_block, read_buffer + index, AES_BLOCKSIZE);
}

output_file.write(
reinterpret_cast<char *>(decryptedBuffer), (remaining_blocks - excludeLastBlock) * AES_BLOCKSIZE
);
lastblock_remains = true;
}

Krypt::ByteArray recover;
if (lastblock_remains) {
Krypt::ByteArray recover =
aes_scheme.decrypt(reinterpret_cast<unsigned char *>(last_block), AES_BLOCKSIZE, iv);

if (excludeLastBlock) {
recover = aes_scheme.decrypt(
reinterpret_cast<unsigned char *>(prev_buffer + (index * AES_BLOCKSIZE)), AES_BLOCKSIZE, iv
);
} else {
recover = aes_scheme.decrypt(
reinterpret_cast<unsigned char *>(prev_buffer + (index * AES_BLOCKSIZE)), remaining_bytes, iv
);
output_file.write(reinterpret_cast<char *>(recover.array), recover.length);
lastblock_remains = false;
}

output_file.write(reinterpret_cast<char *>(recover.array), recover.length);
// new decryption

cnt++;

std::memset((unsigned char *) iv, 0x00, AES_BLOCKSIZE);
delete[] iv;

checkif_replace(args[COMMAND], target_file);
}
}

std::memset((char *) next_buffer, 0x00, BUFFER_BYTESIZE);
std::memset((char *) prev_buffer, 0x00, BUFFER_BYTESIZE);
std::memset((char *) read_buffer, 0x00, BUFFER_BYTESIZE);
std::memset((char *) filesig, 0x00, bconst::FILESIGNATURE.size());
std::memset((Krypt::Bytes *) decryptedBuffer, 0x00, BUFFER_BYTESIZE);

delete[] next_buffer;
delete[] prev_buffer;
delete[] read_buffer;
delete[] filesig;
delete[] decryptedBuffer;
};

std::vector<std::thread> threads;
Expand Down
13 changes: 12 additions & 1 deletion tests/FileCompare.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,18 @@ bool ISEQUAL(const std::string &FileA, const std::string &FileB) {
bconst::bytestream B = byteio::file_read(FileB);

if (A.size() != B.size()) {
std::cout << "\n\t\tSize is not equal\n";
std::cout << "\n\t\tSize is not equal (" << A.size() << ", " << B.size() << ")";

size_t not_equal = 0;
for (size_t i = 0; i < A.size(); ++i) {
if (A[i] != B[i]) {
not_equal++;
}
}

float equality = ((float) A.size() - (float) not_equal) / (float) A.size();
std::cout << "\n\t\tEqual by : " << equality * 100.0f << " percent\n";

return false;
} else if (A.size() <= 0 || B.size() <= 0) {
return false;
Expand Down

0 comments on commit 81b9545

Please sign in to comment.