diff --git a/CMakeLists.txt b/CMakeLists.txt index 5b40a9c..dd4ab9d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,14 +1,52 @@ -cmake_minimum_required (VERSION 3.8) +cmake_minimum_required (VERSION 3.9.2) -project ("EGL2") +project (EGL2) + +set(WITH_GUI ON CACHE BOOL "Compile with a GUI" FORCE) +set(WX_DIR "J:\\Code\\wxWidgets" CACHE PATH "wxWidgets directory" FORCE) +message(${WX_DIR}) aux_source_directory(. FILE_SOURCES) -aux_source_directory(memfs MEMFS_FILE_SOURCES) +aux_source_directory(filesystem FILESYSTEM_FILE_SOURCES) +aux_source_directory(storage STORAGE_FILE_SOURCES) +aux_source_directory(web WEB_FILE_SOURCES) + +if (WITH_GUI) + message("Building with GUI") + + aux_source_directory(gui INTERFACE_FILE_SOURCES) + add_executable(EGL2 WIN32 ${INTERFACE_FILE_SOURCES} "gui/resources.rc" ${FILESYSTEM_FILE_SOURCES} ${STORAGE_FILE_SOURCES} ${WEB_FILE_SOURCES} ${FILE_SOURCES}) + + set(wxWidgets_ROOT_DIR "${WX_DIR}") + set(wxWidgets_LIB_DIR "${WX_DIR}/lib/vc_x64_lib") + set(wxWidgets_EXCLUDE_COMMON_LIBRARIES TRUE) -add_executable (EGL2 ${MEMFS_FILE_SOURCES} ${FILE_SOURCES}) + if (CMAKE_BUILD_TYPE EQUAL "DEBUG") + set(wxWidgets_USE_DEBUG ON) + else() + set(wxWidgets_USE_DEBUG OFF) + endif() + + set(wxWidgets_USE_STATIC ON) + set(wxWidgets_USE_UNICODE ON) + + find_package(wxWidgets REQUIRED COMPONENTS core base png zlib) + include(${wxWidgets_USE_FILE}) + target_link_libraries(EGL2 PUBLIC ${wxWidgets_LIBRARIES}) +else() + message("Building without GUI") + + aux_source_directory(cmd INTERFACE_FILE_SOURCES) + add_executable(EGL2 ${INTERFACE_FILE_SOURCES} ${FILESYSTEM_FILE_SOURCES} ${STORAGE_FILE_SOURCES} ${WEB_FILE_SOURCES} ${FILE_SOURCES}) +endif() set_property(TARGET EGL2 PROPERTY CXX_STANDARD 20) -target_include_directories(EGL2 PRIVATE "$ENV{ProgramFiles\(x86\)}\\WinFsp\\inc") -target_link_libraries(EGL2 "$ENV{ProgramFiles\(x86\)}\\WinFsp\\lib\\winfsp-x64.lib") -configure_file(winfsp-x64.dll winfsp-x64.dll COPYONLY) \ No newline at end of file +find_package(OpenSSL REQUIRED) +find_package(RapidJSON CONFIG REQUIRED) +find_package(ZLIB REQUIRED) +find_package(lz4 REQUIRED) + +target_link_options(EGL2 PRIVATE "/DELAYLOAD:winfsp-x64.dll") +target_include_directories(EGL2 PRIVATE "$ENV{ProgramFiles\(x86\)}\\WinFsp\\inc" ${RAPIDJSON_INCLUDE_DIRS} "libdeflate") +target_link_libraries(EGL2 PRIVATE "$ENV{ProgramFiles\(x86\)}\\WinFsp\\lib\\winfsp-x64.lib" OpenSSL::SSL OpenSSL::Crypto Crypt32 ZLIB::ZLIB lz4::lz4 delayimp "${CMAKE_CURRENT_SOURCE_DIR}\\libdeflate\\libdeflatestatic.lib") \ No newline at end of file diff --git a/CMakeSettings.json b/CMakeSettings.json index 9962fae..c560a4a 100644 --- a/CMakeSettings.json +++ b/CMakeSettings.json @@ -7,9 +7,21 @@ "inheritEnvironments": [ "msvc_x64_x64" ], "buildRoot": "${projectDir}\\out\\build\\${name}", "installRoot": "${projectDir}\\out\\install\\${name}", - "cmakeCommandArgs": "", + "cmakeCommandArgs": "-DVCPKG_TARGET_TRIPLET=x64-windows-static", "buildCommandArgs": "-v", "ctestCommandArgs": "", + "variables": [ ] + }, + { + "name": "x64-RelDbg", + "generator": "Ninja", + "configurationType": "RelWithDebInfo", + "buildRoot": "${projectDir}\\out\\build\\${name}", + "installRoot": "${projectDir}\\out\\install\\${name}", + "cmakeCommandArgs": "-DVCPKG_TARGET_TRIPLET=x64-windows-static", + "buildCommandArgs": "-v", + "ctestCommandArgs": "", + "inheritEnvironments": [ "msvc_x64_x64" ], "variables": [] }, { @@ -18,10 +30,11 @@ "configurationType": "Release", "buildRoot": "${projectDir}\\out\\build\\${name}", "installRoot": "${projectDir}\\out\\install\\${name}", - "cmakeCommandArgs": "", + "cmakeCommandArgs": "-DVCPKG_TARGET_TRIPLET=x64-windows-static", "buildCommandArgs": "-v", "ctestCommandArgs": "", - "inheritEnvironments": [ "msvc_x64_x64" ] + "inheritEnvironments": [ "msvc_x64_x64" ], + "variables": [] } ] } \ No newline at end of file diff --git a/MountedBuild.cpp b/MountedBuild.cpp new file mode 100644 index 0000000..fd0878b --- /dev/null +++ b/MountedBuild.cpp @@ -0,0 +1,501 @@ +#include "MountedBuild.h" + +#include "containers/iterable_queue.h" +#include "containers/semaphore.h" +#include "containers/file_sha.h" + +#include +#include +#include +#include +#include +#include + +#define fail(format, ...) FspServiceLog(EVENTLOG_ERROR_TYPE, format, ##__VA_ARGS__) + +#define SDDL_OWNER "S-1-5-18" // Local System +#define SDDL_DATA "P(A;ID;FRFX;;;WD)" +#define LOG_FLAGS 0 // can also be -1 for all flags + +#define SDDL_ROOT "D:" SDDL_DATA +#define SDDL_MEMFS "O:" SDDL_OWNER "G:" SDDL_OWNER "D:" SDDL_DATA + +MountedBuild::MountedBuild(MANIFEST* manifest, fs::path mountDir, fs::path cachePath, ErrorHandler error) { + this->Manifest = manifest; + this->MountDir = mountDir; + this->CacheDir = cachePath; + this->Error = error; + this->Storage = nullptr; + this->Memfs = nullptr; +} + +MountedBuild::~MountedBuild() { + Unmount(); + if (Storage) { + StorageDelete(Storage); + } +} + +bool MountedBuild::SetupCacheDirectory() { + if (!fs::is_directory(CacheDir) && !fs::create_directories(CacheDir)) { + LogError("can't create cachedir %s\n", CacheDir.string().c_str()); + return false; + } + + char cachePartFolder[3]; + for (int i = 0; i < 256; ++i) { + sprintf(cachePartFolder, "%02X", i); + fs::create_directory(CacheDir / cachePartFolder); + } + return true; +} + +void inline PreloadFile(STORAGE* Storage, MANIFEST_FILE* File, uint32_t ThreadCount, cancel_flag& cancelFlag) { + MANIFEST_CHUNK_PART* ChunkParts; + uint32_t ChunkCount; + uint16_t ChunkStride; + ManifestFileGetChunks(File, &ChunkParts, &ChunkCount, &ChunkStride); + + iterable_queue threads; + for (int i = 0, n = 0; i < ChunkCount * ChunkStride && !cancelFlag.cancelled(); i += ChunkStride) { + auto Chunk = ManifestFileChunkGetChunk((MANIFEST_CHUNK_PART*)((char*)ChunkParts + i)); + if (StorageChunkDownloaded(Storage, Chunk)) { + continue; + } + + // cheap semaphore, keeps thread count low instead of having 81k threads pile up + while (threads.size() >= ThreadCount) { + threads.front().join(); + threads.pop(); + } + + threads.push(std::thread(StorageDownloadChunk, Storage, Chunk, [&n, &ChunkCount](const char* buf, uint32_t bufSize) + { + printf("\r%d downloaded / %d total (%.2f%%)", ++n, ChunkCount, float(n * 100) / ChunkCount); + })); + } + + if (cancelFlag.cancelled()) { + for (auto& thread : threads) { + thread.detach(); + } + } + else { + for (auto& thread : threads) { + thread.join(); + } + } + + printf("\r%d downloaded / %d total (%.2f%%)\n", ChunkCount, ChunkCount, float(100)); +} + +bool inline CompareFile(MANIFEST_FILE* File, fs::path FilePath) { + if (fs::status(FilePath).type() != fs::file_type::regular) { + return false; + } + + if (fs::file_size(FilePath) != ManifestFileGetFileSize(File)) { + return false; + } + + char FileSha[20]; + if (!SHAFile(FilePath, FileSha)) { + return false; + } + + return !memcmp(FileSha, ManifestFileGetSha1(File), 20); +} + +bool MountedBuild::SetupGameDirectory(ProgressSetMaxHandler setMax, ProgressIncrHandler progress, cancel_flag& cancelFlag, uint32_t threadCount, fs::path gameDir) { + if (!fs::is_directory(gameDir) && !fs::create_directories(gameDir)) { + LogError("can't create gamedir %s\n", gameDir.string().c_str()); + return false; + } + + MANIFEST_FILE* Files; + uint32_t FileCount; + uint16_t FileStride; + char FilenameBuffer[128]; + ManifestGetFiles(Manifest, &Files, &FileCount, &FileStride); + setMax(FileCount); + for (int i = 0; i < FileCount * FileStride && !cancelFlag.cancelled(); i += FileStride) { + auto File = (MANIFEST_FILE*)((char*)Files + i); + ManifestFileGetName(File, FilenameBuffer); + fs::path filePath = fs::path(FilenameBuffer); + fs::path folderPath = filePath.parent_path(); + + if (!fs::create_directories(gameDir / folderPath) && !fs::is_directory(gameDir / folderPath)) { + LogError("can't create %s\n", (gameDir / folderPath).string().c_str()); + goto continueFileLoop; + } + do { + if (folderPath.filename() == "Binaries") { + if (!CompareFile(File, gameDir / filePath)) { + PreloadFile(Storage, File, threadCount, cancelFlag); + if (!fs::copy_file(MountDir / filePath, gameDir / filePath, fs::copy_options::overwrite_existing)) { + LogError("failed to copy %s\n", filePath.string().c_str()); + } + } + goto continueFileLoop; + } + folderPath = folderPath.parent_path(); + } while (folderPath != folderPath.root_path()); + + if (!fs::is_symlink(gameDir / filePath)) { + fs::create_symlink(MountDir / filePath, gameDir / filePath); + } + continueFileLoop: + progress(); + } + return true; +} + +bool MountedBuild::StartStorage(uint32_t storageFlags) { + if (Storage) { + return true; + } + char CloudDirHost[64]; + char CloudDirPath[64]; + ManifestGetCloudDir(Manifest, CloudDirHost, CloudDirPath); + if (!StorageCreate(storageFlags, CacheDir.native().c_str(), CloudDirHost, CloudDirPath, &Storage)) { + LogError("cannot create storage"); + return false; + } + return true; +} + +bool MountedBuild::PreloadAllChunks(ProgressSetMaxHandler setMax, ProgressIncrHandler progress, cancel_flag& cancelFlag, uint32_t threadCount) { + std::shared_ptr* ChunkList; + uint32_t ChunkCount; + ManifestGetChunks(Manifest, &ChunkList, &ChunkCount); + LogError("set chunk count to %u", ChunkCount); + setMax(ChunkCount); + + iterable_queue threads; + + for (auto Chunk = ChunkList; Chunk != ChunkList + ChunkCount && !cancelFlag.cancelled(); Chunk++) { + if (StorageChunkDownloaded(Storage, Chunk->get())) { + //LogError("already downloaded", ChunkCount); + progress(); + continue; + } + + // cheap semaphore, keeps thread count low instead of having 81k threads pile up + while (threads.size() >= threadCount) { + threads.front().join(); + threads.pop(); + } + + threads.push(std::thread(StorageDownloadChunk, Storage, Chunk->get(), std::bind(progress))); + } + + if (cancelFlag.cancelled()) { + for (auto& thread : threads) { + thread.detach(); + progress(); + } + } + else { + for (auto& thread : threads) { + thread.join(); + progress(); + } + } + return true; +} + +#define HTONLL(x) ((1==htonl(1)) ? (x) : (((uint64_t)htonl((x) & 0xFFFFFFFFUL)) << 32) | htonl((uint32_t)((x) >> 32))) +#define NTOHLL(x) ((1==ntohl(1)) ? (x) : (((uint64_t)ntohl((x) & 0xFFFFFFFFUL)) << 32) | ntohl((uint32_t)((x) >> 32))) +auto hash = [](const char* n) { return (*((uint64_t*)n)) ^ (*(((uint64_t*)n) + 1)); }; +auto equal = [](const char* a, const char* b) {return !memcmp(a, b, 16); }; +void MountedBuild::PurgeUnusedChunks(ProgressSetMaxHandler setMax, ProgressIncrHandler progress, cancel_flag& cancelFlag) { + std::shared_ptr* ChunkList; + uint32_t ChunkCount; + ManifestGetChunks(Manifest, &ChunkList, &ChunkCount); + + std::unordered_set ManifestGuids; + ManifestGuids.reserve(ChunkCount); + for (auto Chunk = ChunkList; Chunk != ChunkList + ChunkCount; Chunk++) { + ManifestGuids.insert(ManifestChunkGetGuid(Chunk->get())); + } + + setMax(std::count_if(fs::recursive_directory_iterator(CacheDir), fs::recursive_directory_iterator(), [](const fs::directory_entry& f) { return f.is_regular_file(); })); + + char guidBuffer[16]; + char guidBuffer2[16]; + for (auto& p : fs::recursive_directory_iterator(CacheDir)) { + if (cancelFlag.cancelled()) { + break; + } + if (!p.is_regular_file()) { + continue; + } + + sscanf(p.path().filename().string().c_str(), "%016llX%016llX", guidBuffer2, guidBuffer2 + 8); + *(unsigned long long*)guidBuffer = HTONLL(*(unsigned long long*)guidBuffer2); + *(unsigned long long*)(guidBuffer + 8) = HTONLL(*(unsigned long long*)(guidBuffer2 + 8)); + if (!ManifestGuids.erase(guidBuffer)) { + fs::remove(p); + } + progress(); + } +} + +void MountedBuild::VerifyAllChunks(ProgressSetMaxHandler setMax, ProgressIncrHandler progress, cancel_flag& cancelFlag, uint32_t threadCount) { + std::shared_ptr* ChunkList; + uint32_t ChunkCount; + ManifestGetChunks(Manifest, &ChunkList, &ChunkCount); + setMax(ChunkCount); + + iterable_queue threads; + + for (auto Chunk = ChunkList; Chunk != ChunkList + ChunkCount && !cancelFlag.cancelled(); Chunk++) { + if (!StorageChunkDownloaded(Storage, Chunk->get())) { + continue; + } + + // cheap semaphore, keeps thread count low instead of having 81k threads pile up + while (threads.size() >= threadCount) { + threads.front().join(); + progress(); + threads.pop(); + } + + threads.push(std::thread(StorageVerifyChunk, Storage, Chunk->get())); + } + + if (cancelFlag.cancelled()) { + for (auto& thread : threads) { + thread.detach(); + } + } + else { + for (auto& thread : threads) { + thread.join(); + progress(); + } + } +} + +void MountedBuild::LaunchGame(fs::path gameDir, const char* additionalArgs) { + char ExeBuf[MAX_PATH]; + char CmdBuf[512]; + ManifestGetLaunchInfo(Manifest, ExeBuf, CmdBuf); + strcat(CmdBuf, " "); + strcat(CmdBuf, additionalArgs); + fs::path exePath = gameDir / ExeBuf; + + PROCESS_INFORMATION pi; + STARTUPINFOA si; + + memset(&pi, 0, sizeof(pi)); + memset(&si, 0, sizeof(si)); + si.cb = sizeof(si); + + CreateProcessA(exePath.string().c_str(), CmdBuf, NULL, NULL, FALSE, DETACHED_PROCESS, NULL, exePath.parent_path().string().c_str(), &si, &pi); + + CloseHandle(pi.hProcess); + CloseHandle(pi.hThread); +} + +bool MountedBuild::Mount() { + if (Mounted()) { + return true; + } + + NTSTATUS Result; + + FspDebugLogSetHandle(GetStdHandle(STD_ERROR_HANDLE)); + Provider = CreateProvider( // might be a better way (lambdas cause dynamic memory allocation, but idk) + std::bind(&MountedBuild::FileOpen, this, std::placeholders::_1, std::placeholders::_2), + std::bind(&MountedBuild::FileClose, this, std::placeholders::_1), + std::bind(&MountedBuild::FileRead, this, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3, std::placeholders::_4, std::placeholders::_5) + ); + + + { + PVOID securityDescriptor; + ULONG securityDescriptorSize; + if (!ConvertStringSecurityDescriptorToSecurityDescriptorA(SDDL_MEMFS, SDDL_REVISION_1, &securityDescriptor, &securityDescriptorSize)) { + Result = FspNtStatusFromWin32(GetLastError()); + fail(L"invalid sddl: %08x", Result); + goto exit; + } + + auto downloadSize = ManifestDownloadSize(Manifest); + auto installSize = ManifestInstallSize(Manifest); + Result = MemfsCreateFunnel( + MemfsDisk, // flags + INFINITE, // file timeout + 1024, // max file nodes/files + L"EGL2", // file system name + 0, // volume prefix (instead of \\server or whatever), can be optional + L"EGL2", // volume label + installSize, + installSize - downloadSize, + securityDescriptor, + securityDescriptorSize, + Provider, + &Memfs); + LocalFree(securityDescriptor); + } + + if (!NT_SUCCESS(Result)) + { + fail(L"cannot create MEMFS"); + goto exit; + } + + { + MANIFEST_FILE* Files; + uint32_t FileCount; + uint16_t FileStride; + char FilenameBuffer[128]; + FilenameBuffer[0] = '/'; + ManifestGetFiles(Manifest, &Files, &FileCount, &FileStride); + std::set directories; + for (int i = 0; i < FileCount * FileStride; i += FileStride) { + ManifestFileGetName((MANIFEST_FILE*)((char*)Files + i), FilenameBuffer + 1); + fs::path curPath = fs::path(FilenameBuffer).parent_path().make_preferred(); + do { + directories.insert(curPath); + curPath = curPath.parent_path(); + } while (curPath != curPath.root_path()); + } + for (auto& dir : directories) { + Result = CreateFsFile(Memfs, (PWSTR)dir.native().c_str(), true); + if (!NT_SUCCESS(Result)) + { + fail(L"cannot create directory %08x", Result); + goto exit; + } + } + for (int i = 0; i < FileCount * FileStride; i += FileStride) { + ManifestFileGetName((MANIFEST_FILE*)((char*)Files + i), FilenameBuffer + 1); + Result = CreateFsFile(Memfs, (PWSTR)fs::path(FilenameBuffer).make_preferred().native().c_str(), false); + if (!NT_SUCCESS(Result)) + { + fail(L"cannot create file %08x", Result); + goto exit; + } + } + } + + FspFileSystemSetDebugLog(MemfsFileSystem(Memfs), LOG_FLAGS); + + { + PVOID rootSecurity; + if (!ConvertStringSecurityDescriptorToSecurityDescriptorA(SDDL_ROOT, SDDL_REVISION_1, &rootSecurity, NULL)) { + Result = FspNtStatusFromWin32(GetLastError()); + fail(L"invalid root sddl: %08x", Result); + goto exit; + } + Result = FspFileSystemSetMountPointEx(MemfsFileSystem(Memfs), (PWSTR)MountDir.native().c_str(), rootSecurity); + LocalFree(rootSecurity); + } + + if (!NT_SUCCESS(Result)) + { + fail(L"cannot mount MEMFS %08x", Result); + goto exit; + } + + Result = MemfsStart(Memfs); + if (!NT_SUCCESS(Result)) + { + fail(L"cannot start MEMFS %08x", Result); + goto exit; + } + + Result = STATUS_SUCCESS; + +exit: + printf("result %d\n", Result); + if (!NT_SUCCESS(Result) && Memfs) { + MemfsDelete(Memfs); + Memfs = 0; + } + + if (!NT_SUCCESS(Result)) { + fail(L"Failed: %d", Result); + } + + return NT_SUCCESS(Result); +} + +bool MountedBuild::Unmount() { + if (!Mounted()) { + return true; + } + + MemfsStop(Memfs); + MemfsDelete(Memfs); + Memfs = 0; + + return true; +} + +bool MountedBuild::Mounted() { + return Memfs; +} + +void MountedBuild::LogError(const char* format, ...) +{ + va_list argp; + va_start(argp, format); + char* buf = new char[snprintf(nullptr, 0, format, argp) + 1]; + vsprintf(buf, format, argp); + va_end(argp); + Error(buf); + delete[] buf; +} + +PVOID MountedBuild::FileOpen(PCWSTR fileName, UINT64* fileSize) { + auto File = ManifestGetFile(Manifest, fs::path(fileName).generic_string().c_str() + 1); + if (File) { + *fileSize = ManifestFileGetFileSize(File); + } + else { + *fileSize = 0; + } + return File; +} + +void MountedBuild::FileClose(PVOID Handle) { + // no need to do anything, it's just a MANIFEST_FILE* +} + +void MountedBuild::FileRead(PVOID Handle, PVOID Buffer, UINT64 offset, ULONG length, ULONG* bytesRead) { + auto File = (MANIFEST_FILE*)Handle; + uint32_t ChunkStartIndex, ChunkStartOffset; + if (ManifestFileGetChunkIndex(File, offset, &ChunkStartIndex, &ChunkStartOffset)) { + MANIFEST_CHUNK_PART* ChunkParts; + uint32_t ChunkPartCount, BytesRead = 0; + uint16_t StrideSize; + ManifestFileGetChunks(File, &ChunkParts, &ChunkPartCount, &StrideSize); + for (int i = ChunkStartIndex * StrideSize; i < ChunkPartCount * StrideSize; i += StrideSize) { + auto chunkPart = (MANIFEST_CHUNK_PART*)((char*)ChunkParts + i); + uint32_t ChunkOffset, ChunkSize; + ManifestFileChunkGetData(chunkPart, &ChunkOffset, &ChunkSize); + char* ChunkBuffer = new char[ChunkSize]; + StorageDownloadChunkPart(Storage, chunkPart, ChunkBuffer); + if (((int64_t)length - (int64_t)BytesRead) > (int64_t)ChunkSize - (int64_t)ChunkStartOffset) { // copy the entire buffer over + memcpy((char*)Buffer + BytesRead, ChunkBuffer + ChunkStartOffset, ChunkSize - ChunkStartOffset); + BytesRead += ChunkSize - ChunkStartOffset; + } + else { // copy what it needs to fill up the rest + memcpy((char*)Buffer + BytesRead, ChunkBuffer + ChunkStartOffset, length - BytesRead); + BytesRead += (int64_t)length - (int64_t)BytesRead; + delete[] ChunkBuffer; + *bytesRead = BytesRead; + return; + } + delete[] ChunkBuffer; + ChunkStartOffset = 0; + } + *bytesRead = BytesRead; + } + else { + *bytesRead = 0; + } +} \ No newline at end of file diff --git a/MountedBuild.h b/MountedBuild.h new file mode 100644 index 0000000..f40fb3f --- /dev/null +++ b/MountedBuild.h @@ -0,0 +1,47 @@ +#pragma once + +#include +#include + +#include "containers/cancel_flag.h" +#include "filesystem/memfs.h" +#include "web/manifest.h" +#include "storage/storage.h" + +namespace fs = std::filesystem; + +typedef std::function ProgressSetMaxHandler; +typedef std::function ProgressIncrHandler; +typedef std::function ErrorHandler; + +class MountedBuild { +public: + MountedBuild(MANIFEST* manifest, fs::path mountDir, fs::path cachePath, ErrorHandler error); + ~MountedBuild(); + + bool SetupCacheDirectory(); + bool SetupGameDirectory(ProgressSetMaxHandler setMax, ProgressIncrHandler progress, cancel_flag& cancelFlag, uint32_t threadCount, fs::path gameDir); + bool StartStorage(uint32_t storageFlags); + bool PreloadAllChunks(ProgressSetMaxHandler setMax, ProgressIncrHandler progress, cancel_flag& cancelFlag, uint32_t threadCount); + void PurgeUnusedChunks(ProgressSetMaxHandler setMax, ProgressIncrHandler progress, cancel_flag& cancelFlag); + void VerifyAllChunks(ProgressSetMaxHandler setMax, ProgressIncrHandler progress, cancel_flag& cancelFlag, uint32_t threadCount); + void LaunchGame(fs::path gameDir, const char* additionalArgs); + bool Mount(); + bool Unmount(); + bool Mounted(); + +private: + void LogError(const char* format, ...); + + PVOID FileOpen(PCWSTR fileName, UINT64* fileSize); + void FileClose(PVOID Handle); + void FileRead(PVOID Handle, PVOID Buffer, UINT64 offset, ULONG length, ULONG* bytesRead); + + fs::path MountDir; + fs::path CacheDir; + MANIFEST* Manifest; + STORAGE* Storage; + MEMFS* Memfs; + MEMFS_FILE_PROVIDER* Provider; + ErrorHandler Error; +}; \ No newline at end of file diff --git a/README.md b/README.md index cd7a8d6..cf5fd1a 100644 --- a/README.md +++ b/README.md @@ -1,19 +1,48 @@ # EGL2 - ## An alternative to the Epic Games Launcher ## Features - Extremely quick updating - - Or, download the update it in the background while playing! (Results may vary) - - Load multiple versions of games side by side + - Or, download the update it in the background while playing! (Not recommended, but possible!) - Implements directly with your file system - Low memory footprint - Use your favorite datamining tools or extractors seamlessly - Trade performance for disk space, keep your game install compressed + - Not added yet: + - Load multiple versions of games side by side -## Installation -The installation process can be a little awkward, but here's a simple step by step guide: -TODO: Add instructions for WinFsp +## Installation / Getting Started +The process can be a little awkward, so here's a simple step by step guide: + 1. Install [WinFsp](http://www.secfs.net/winfsp/), which can be downloaded [here](https://github.com/billziss-gh/winfsp/releases/download/v1.7B1/winfsp-1.7.20038.msi). (Only the "Core" feature is necessary) + 2. Download the latest install of EGL2 [here](https://github.com/WorkingRobot/EGL2/releases/latest/download/EGL2.exe). + 3. When ran, you should be greeted with a window with a several buttons. If instead you see an error, follow the instructions it shows. + 4. Click "Setup", where a new window will prompt you to setup where you want your data to be stored. + 5. If you want to be able to play the game (and not just download it), you will need to select a game directory as well. If you do not want to play, deselect the checkbox in the bottom right corner. + 6. By default, it is configured to download and install the game in a decompressed state. This will simply download and decompress your game to it's full size. + 7. Feel free to use the other compression methods at the expense of download/update time (due to compression). I recommend using "LZ4" and using the slowest compression level setting you feel comfortable with. LZ4 has decompression speeds that are extremely fast (~4 GB/s), and the slowest compression level can compress your install by over 2x. The slowest setting with LZ4 can compress ~40MB/s on a good computer, so do be conservative if your computer isn't great. + 8. Close the setup window when done. + 9. If the folders you chose in your setup exist, the other buttons should be enabled. + 10. Click "Update". This will take about an hour or so depending on the settings you chose and the computer you have. (You are downloading an entire game and possibly compressing it, too, y'know.) + 11. When your game finishes updating/installing, click "Start". + 12. If all goes well, you should be able to click "Play". Once you enter in your exchange code, Fortnite should start! + 13. When relaunching, it's recommended to click "Update" again, in the event that a new hotfix or update has occurred that you weren't aware of. + 14. You will need to relaunch EGL2 if an update was pushed after starting. ## Building -TODO: Ask them to add the winfsp dll and whatnot \ No newline at end of file +I use [CMake for Visual Studio](https://docs.microsoft.com/en-us/cpp/build/cmake-projects-in-visual-studio), so your results may vary. + +### Prerequisites + - CMake + - MSVC (with C++20 support) + - vcpkg + - (Install these packages with `x64-windows-static`) + - OpenSSL + - RapidJSON + - zLib + - LZ4 + - wxWidgets (if compiling with a GUI) + - WinFsp with the "Developer" feature installed + +### CMake Build Options + - `WITH_GUI` - if set to false, a command line version will be built instead. (Going to be honest, I haven't checked for any crashes/errors in the console version, but it should work) + - `WX_DIR` - if you want to build EGL2 with a GUI, set this path to your wxWidgets directory. diff --git a/cmd/cmdmain.cpp b/cmd/cmdmain.cpp new file mode 100644 index 0000000..998bb1e --- /dev/null +++ b/cmd/cmdmain.cpp @@ -0,0 +1,232 @@ +#include "../MountedBuild.h" +#include "../winfspcheck.h" + +static inline const char* humanSize(uint64_t bytes) +{ + char* suffix[] = { "B", "KB", "MB", "GB", "TB" }; + char length = sizeof(suffix) / sizeof(suffix[0]); + + int i = 0; + double dblBytes = bytes; + + if (bytes > 1024) { + for (i = 0; (bytes / 1024) > 0 && i < length - 1; i++, bytes /= 1024) + dblBytes = bytes / 1024.0; + } + + static char output[16]; + sprintf(output, "%.04lf %s", dblBytes, suffix[i]); + return output; +} + +#define argtos(v) if (arge > ++argp) v = *argp; else goto usage +#define argtol(v) if (arge > ++argp) v = wcstol_deflt(*argp, v); else goto usage + +static inline ULONG wcstol_deflt(wchar_t* w, ULONG deflt) +{ + wchar_t* endp; + ULONG ul = wcstol(w, &endp, 0); + return L'\0' != w[0] && L'\0' == *endp ? ul : deflt; +} + +int wmain(ULONG argc, PWSTR* argv) +{ + { + auto result = LoadWinFsp(); + if (result != WinFspCheckResult::LOADED) { + switch (result) + { + case WinFspCheckResult::CANNOT_ENUMERATE: + printf("Could not iterate over drivers to get WinFsp install. System-specific error: %d\n", GetLastError()); + break; + case WinFspCheckResult::NOT_FOUND: + printf("Could not find WinFsp as an installed driver. Maybe you don't have it installed?\n"); + break; + case WinFspCheckResult::NO_DLL: + printf("Could not find WinFsp's DLL in the driver's folder. Try reinstalling WinFsp.\n"); + break; + case WinFspCheckResult::CANNOT_LOAD: + printf("Could not load WinFsp's DLL in the driver's folder. Try reinstalling WinFsp.\n"); + break; + default: + printf("An unknown error occurred when trying to load WinFsp's DLL: %d\n", result); + break; + } + return 0; + } + } + + wchar_t** argp, ** arge; + + PWSTR GamePath = 0, CachePath = 0, MountPath = 0; + ULONG DownloadThreadCount = 0, CompressionMethod = 0, CompressionLevel = 0; + bool Verify = false, RemoveUnused = false; + for (argp = argv + 1, arge = argv + argc; arge > argp; argp++) + { + if (L'-' != argp[0][0]) + break; + switch (argp[0][1]) + { + case L'?': + goto usage; + case L'G': + argtos(GamePath); + break; + case L'P': + argtol(DownloadThreadCount); + break; + case L'C': + argtos(CachePath); + break; + case L'M': + argtos(MountPath); + break; + case L'R': + RemoveUnused = true; + break; + case L'V': + Verify = true; + break; + case 'S': + argtol(CompressionMethod); + break; + case 's': + argtol(CompressionLevel); + break; + default: + goto usage; + } + } + if (!CachePath || !MountPath) { + goto usage; + } + uint32_t StorageFlags = 0; + if (Verify) { + StorageFlags |= StorageVerifyHashes; + } + if (CompressionMethod) { + switch (CompressionMethod) + { + case 1: + StorageFlags |= StorageDecompressed; + break; + case 2: + StorageFlags |= StorageCompressed; + break; + case 3: + StorageFlags |= StorageCompressLZ4; + break; + case 4: + StorageFlags |= StorageCompressZlib; + break; + default: + wprintf(L"Unknown compression method %d\n", CompressionMethod); + goto usage; + break; + } + } + else { + StorageFlags |= StorageCompressLZ4; + } + if (CompressionLevel) { + switch (CompressionLevel) + { + case 1: + StorageFlags |= StorageCompressFastest; + break; + case 2: + StorageFlags |= StorageCompressFast; + break; + case 3: + StorageFlags |= StorageCompressNormal; + break; + case 4: + StorageFlags |= StorageCompressSlow; + break; + case 5: + StorageFlags |= StorageCompressSlowest; + break; + default: + wprintf(L"Unknown compression level %d\n", CompressionLevel); + goto usage; + break; + } + } + else { + StorageFlags |= StorageCompressSlowest; + } + + MANIFEST* manifest; + MANIFEST_AUTH* auth; + STORAGE* storage; + ManifestAuthGrab(&auth); + ManifestAuthGetManifest(auth, "", &manifest); + printf("Total download size: %s\n", humanSize(ManifestInstallSize(manifest))); + (void)getchar(); + printf("Making build\n"); + MountedBuild* Build = new MountedBuild(manifest, MountPath, CachePath, [](const char* error) {printf("%s\n", error); }); + + printf("Setting up cache dir\n"); + if (!Build->SetupCacheDirectory()) { + printf("failed to setup cache dir\n"); + } + + printf("starting storage\n"); + if (!Build->StartStorage(StorageFlags)) { + printf("failed to start storage\n"); + } + + if (RemoveUnused) { + cancel_flag flag; + Build->PurgeUnusedChunks([](uint32_t max) {}, []() {}, flag); + } + + if (DownloadThreadCount) { + printf("Predownloading\n"); + cancel_flag flag; + if (!Build->PreloadAllChunks([](uint32_t max) {}, []() {}, flag, DownloadThreadCount)) { + printf("failed to preload\n"); + } + } + + printf("Starting\n"); + if (!Build->Mount()) { + printf("failed to start\n"); + } + + if (GamePath) { + printf("Setting up game dir\n"); + cancel_flag flag; + if (!Build->SetupGameDirectory([](uint32_t max) {}, []() {}, flag, DownloadThreadCount ? DownloadThreadCount : 32, GamePath)) { + printf("failed to setup game dir\n"); + } + } + + printf("Started, press any key to close\n"); + (void)getchar(); + printf("Closing\n"); + delete Build; + printf("Closed\n"); + return STATUS_SUCCESS; + +usage: + static char usage[] = "" + "usage: EGL2.exe OPTIONS\n" + "\n" + "options:\n" + " -G GameMountDir [optional: make Fortnite launchable here]\n" + " -P ThreadCount [optional: predownload all chunks]\n" + " -C CacheDir [directory to place downloaded chunk files]\n" + " -M MountDir [drive (Y:) or directory to mount filesystem to]\n" + " -R [optional: before mounting, remove all chunks that aren't used in the manifest (for when updating)]\n" + " -V [optional: verify all read chunks to sha hashes and redownload if necessary]\n" + // 1: uncompressed, 2: as downloaded (zlib), 3: lz4, 4: zlib + " -S Compression [optional: Type of compression used (1-4)]\n" + // 1: fastest, 2: fast, 3: normal, 4: slow, 5: slowest + // only used when -S is 3 or 4 + " -s CompressionLvl [optional: Amount of compression used (1-5)]\n"; + + printf(usage); + + return STATUS_UNSUCCESSFUL; +} \ No newline at end of file diff --git a/containers/cancel_flag.h b/containers/cancel_flag.h new file mode 100644 index 0000000..b15845b --- /dev/null +++ b/containers/cancel_flag.h @@ -0,0 +1,33 @@ +#pragma once + +#include + +class cancel_flag { +public: + bool cancelled() { + return value.load(); + } + + bool cancelled() const { + return value.load(); + } + + bool cancelled() volatile { + return value.load(); + } + + bool cancelled() const volatile { + return value.load(); + } + + void cancel() { + value.store(true); + } + + void cancel() volatile { + value.store(true); + } + +protected: + std::atomic value = false; +}; \ No newline at end of file diff --git a/containers/file_sha.h b/containers/file_sha.h new file mode 100644 index 0000000..cec848f --- /dev/null +++ b/containers/file_sha.h @@ -0,0 +1,30 @@ +#pragma once + +#include +#include +#include + +#include + +inline bool SHAFile(fs::path path, char OutHash[SHA_DIGEST_LENGTH]) { + constexpr int buffer_size = 1 << 13; + char buffer[buffer_size]; + + SHA_CTX ctx; + SHA1_Init(&ctx); + + std::ifstream fp(path.c_str(), std::ios::in | std::ios::binary); + + if (!fp.good()) { + return false; + } + + while (fp.good()) { + fp.read(buffer, buffer_size); + SHA1_Update(&ctx, buffer, fp.gcount()); + } + fp.close(); + + SHA1_Final((unsigned char*)OutHash, &ctx); + return true; +} \ No newline at end of file diff --git a/containers/iterable_queue.h b/containers/iterable_queue.h new file mode 100644 index 0000000..33a93a8 --- /dev/null +++ b/containers/iterable_queue.h @@ -0,0 +1,16 @@ +#pragma once + +#include + +template> +class iterable_queue : public std::queue +{ +public: + typedef typename Container::iterator iterator; + typedef typename Container::const_iterator const_iterator; + + iterator begin() { return this->c.begin(); } + iterator end() { return this->c.end(); } + const_iterator begin() const { return this->c.begin(); } + const_iterator end() const { return this->c.end(); } +}; \ No newline at end of file diff --git a/containers/semaphore.h b/containers/semaphore.h new file mode 100644 index 0000000..81f541f --- /dev/null +++ b/containers/semaphore.h @@ -0,0 +1,30 @@ +#pragma once + +#include +#include + +class semaphore +{ +private: + std::mutex mutex_; + std::condition_variable condition_; + unsigned long count_ = 0; // Initialized as locked. + +public: + semaphore(long initial_count) { + count_ = initial_count; + } + + void notify() { // a thread is released + std::lock_guard lock(mutex_); + ++count_; + condition_.notify_one(); + } + + void wait() { // wait for a thread to be released + std::unique_lock lock(mutex_); + while (!count_) // Handle spurious wake-ups. + condition_.wait(lock); + --count_; + } +}; \ No newline at end of file diff --git a/memfs/memfs.cpp b/filesystem/memfs.cpp similarity index 87% rename from memfs/memfs.cpp rename to filesystem/memfs.cpp index 570b4db..0c49a3a 100644 --- a/memfs/memfs.cpp +++ b/filesystem/memfs.cpp @@ -1,25 +1,4 @@ -/** - * @file memfs.cpp - * - * @copyright 2015-2020 Bill Zissimopoulos - */ - /* - * This file is part of WinFsp. - * - * You can redistribute it and/or modify it under the terms of the GNU - * General Public License version 3 as published by the Free Software - * Foundation. - * - * Licensees holding a valid commercial license may use this software - * in accordance with the commercial license agreement provided in - * conjunction with the software. The terms and conditions of any such - * commercial license agreement shall govern, supersede, and render - * ineffective any application of the GPLv3 license to this software, - * notwithstanding of any reference thereto in the software or - * associated repository. - */ - -//#undef _DEBUG + #include "memfs.h" #include #include @@ -35,13 +14,13 @@ typedef struct _MEMFS_FILE_PROVIDER { std::function Open; std::function Close; - std::function Read; + std::function Read; } MEMFS_FILE_PROVIDER; MEMFS_FILE_PROVIDER* CreateProvider( std::function Open, std::function Close, - std::function Read) + std::function Read) { auto Provider = new MEMFS_FILE_PROVIDER; Provider->Open = Open; @@ -55,10 +34,6 @@ void CloseProvider(MEMFS_FILE_PROVIDER* Provider) delete Provider; } -/* - * MEMFS - */ - static inline UINT64 MemfsGetSystemTime(VOID) { @@ -146,8 +121,6 @@ typedef struct _MEMFS_FILE_NODE { WCHAR FileName[MEMFS_MAX_PATH]; FSP_FSCTL_FILE_INFO FileInfo; - SIZE_T FileSecuritySize; - PVOID FileSecurity; PVOID FileData; volatile LONG RefCount; } MEMFS_FILE_NODE; @@ -171,8 +144,11 @@ typedef struct _MEMFS MEMFS_FILE_NODE_MAP* FileNodeMap; MEMFS_FILE_PROVIDER* FileProvider; ULONG MaxFileNodes; - UINT16 VolumeLabelLength; + UINT64 VolumeTotal; + UINT64 VolumeFree; WCHAR VolumeLabel[32]; + PVOID Security; + SIZE_T SecuritySize; } MEMFS; static inline @@ -205,7 +181,6 @@ static inline VOID MemfsFileNodeDelete(MEMFS_FILE_PROVIDER* Provider, MEMFS_FILE_NODE* FileNode) { Provider->Close(FileNode->FileData); - free(FileNode->FileSecurity); free(FileNode); } @@ -534,10 +509,10 @@ static NTSTATUS GetVolumeInfo(FSP_FILE_SYSTEM* FileSystem, { MEMFS* Memfs = (MEMFS*)FileSystem->UserContext; - VolumeInfo->TotalSize = Memfs->MaxFileNodes * (UINT64)4096; - VolumeInfo->FreeSize = (Memfs->MaxFileNodes - MemfsFileNodeMapCount(Memfs->FileNodeMap)) * 4096; - VolumeInfo->VolumeLabelLength = Memfs->VolumeLabelLength; - memcpy(VolumeInfo->VolumeLabel, Memfs->VolumeLabel, Memfs->VolumeLabelLength); + VolumeInfo->TotalSize = Memfs->VolumeTotal; + VolumeInfo->FreeSize = Memfs->VolumeFree; + VolumeInfo->VolumeLabelLength = wcslen(Memfs->VolumeLabel) * sizeof(WCHAR); + wcscpy_s(VolumeInfo->VolumeLabel, Memfs->VolumeLabel); return STATUS_SUCCESS; } @@ -546,20 +521,7 @@ static NTSTATUS SetVolumeLabel(FSP_FILE_SYSTEM* FileSystem, PWSTR VolumeLabel, FSP_FSCTL_VOLUME_INFO* VolumeInfo) { - MEMFS* Memfs = (MEMFS*)FileSystem->UserContext; - - Memfs->VolumeLabelLength = (UINT16)(wcslen(VolumeLabel) * sizeof(WCHAR)); - if (Memfs->VolumeLabelLength > sizeof Memfs->VolumeLabel) - Memfs->VolumeLabelLength = sizeof Memfs->VolumeLabel; - memcpy(Memfs->VolumeLabel, VolumeLabel, Memfs->VolumeLabelLength); - - VolumeInfo->TotalSize = Memfs->MaxFileNodes * 4096; - VolumeInfo->FreeSize = - (Memfs->MaxFileNodes - MemfsFileNodeMapCount(Memfs->FileNodeMap)) * 4096; - VolumeInfo->VolumeLabelLength = Memfs->VolumeLabelLength; - memcpy(VolumeInfo->VolumeLabel, Memfs->VolumeLabel, Memfs->VolumeLabelLength); - - return STATUS_SUCCESS; + return STATUS_ACCESS_DENIED; } static NTSTATUS GetSecurityByName(FSP_FILE_SYSTEM* FileSystem, @@ -583,15 +545,16 @@ static NTSTATUS GetSecurityByName(FSP_FILE_SYSTEM* FileSystem, if (0 != PSecurityDescriptorSize) { - if (FileNode->FileSecuritySize > * PSecurityDescriptorSize) + if (Memfs->SecuritySize > *PSecurityDescriptorSize) { - *PSecurityDescriptorSize = FileNode->FileSecuritySize; + *PSecurityDescriptorSize = Memfs->SecuritySize; return STATUS_BUFFER_OVERFLOW; } - *PSecurityDescriptorSize = FileNode->FileSecuritySize; + *PSecurityDescriptorSize = Memfs->SecuritySize; + if (0 != SecurityDescriptor) - memcpy(SecurityDescriptor, FileNode->FileSecurity, FileNode->FileSecuritySize); + memcpy(SecurityDescriptor, Memfs->Security, Memfs->SecuritySize); } return STATUS_SUCCESS; @@ -729,8 +692,6 @@ NTSTATUS Flush(FSP_FILE_SYSTEM* FileSystem, { MEMFS_FILE_NODE* FileNode = (MEMFS_FILE_NODE*)FileNode0; - /* nothing to flush, since we do not cache anything */ - if (0 != FileNode) { MemfsFileNodeGetFileInfo(FileNode, FileInfo); @@ -813,17 +774,19 @@ static NTSTATUS GetSecurity(FSP_FILE_SYSTEM* FileSystem, PVOID FileNode0, PSECURITY_DESCRIPTOR SecurityDescriptor, SIZE_T* PSecurityDescriptorSize) { + MEMFS* Memfs = (MEMFS*)FileSystem->UserContext; + MEMFS_FILE_NODE* FileNode = (MEMFS_FILE_NODE*)FileNode0; - if (FileNode->FileSecuritySize > * PSecurityDescriptorSize) + if (Memfs->SecuritySize > * PSecurityDescriptorSize) { - *PSecurityDescriptorSize = FileNode->FileSecuritySize; + *PSecurityDescriptorSize = Memfs->SecuritySize; return STATUS_BUFFER_OVERFLOW; } - *PSecurityDescriptorSize = FileNode->FileSecuritySize; + *PSecurityDescriptorSize = Memfs->SecuritySize; if (0 != SecurityDescriptor) - memcpy(SecurityDescriptor, FileNode->FileSecurity, FileNode->FileSecuritySize); + memcpy(SecurityDescriptor, Memfs->Security, Memfs->SecuritySize); return STATUS_SUCCESS; } @@ -832,36 +795,7 @@ static NTSTATUS SetSecurity(FSP_FILE_SYSTEM* FileSystem, PVOID FileNode0, SECURITY_INFORMATION SecurityInformation, PSECURITY_DESCRIPTOR ModificationDescriptor) { - // to be fair, i have no idea when and how this runs, maybe add a logging statement to check back later? - - MEMFS_FILE_NODE* FileNode = (MEMFS_FILE_NODE*)FileNode0; - PSECURITY_DESCRIPTOR NewSecurityDescriptor, FileSecurity; - SIZE_T FileSecuritySize; - NTSTATUS Result; - - Result = FspSetSecurityDescriptor( - FileNode->FileSecurity, - SecurityInformation, - ModificationDescriptor, - &NewSecurityDescriptor); - if (!NT_SUCCESS(Result)) - return Result; - - FileSecuritySize = GetSecurityDescriptorLength(NewSecurityDescriptor); - FileSecurity = (PSECURITY_DESCRIPTOR)malloc(FileSecuritySize); - if (0 == FileSecurity) - { - FspDeleteSecurityDescriptor(NewSecurityDescriptor, (NTSTATUS(*)())FspSetSecurityDescriptor); - return STATUS_INSUFFICIENT_RESOURCES; - } - memcpy(FileSecurity, NewSecurityDescriptor, FileSecuritySize); - FspDeleteSecurityDescriptor(NewSecurityDescriptor, (NTSTATUS(*)())FspSetSecurityDescriptor); - - free(FileNode->FileSecurity); - FileNode->FileSecuritySize = FileSecuritySize; - FileNode->FileSecurity = FileSecurity; - - return STATUS_SUCCESS; + return STATUS_ACCESS_DENIED; } typedef struct _MEMFS_READ_DIRECTORY_CONTEXT @@ -991,6 +925,11 @@ NTSTATUS MemfsCreateFunnel( ULONG MaxFileNodes, PWSTR FileSystemName, PWSTR VolumePrefix, + PWSTR VolumeLabel, + UINT64 VolumeTotal, + UINT64 VolumeFree, + PVOID Security, + SIZE_T SecuritySize, MEMFS_FILE_PROVIDER* FileProvider, MEMFS** PMemfs) { @@ -1013,8 +952,10 @@ NTSTATUS MemfsCreateFunnel( memset(Memfs, 0, sizeof * Memfs); Memfs->MaxFileNodes = MaxFileNodes; - Memfs->FileProvider = FileProvider; + Memfs->Security = malloc(SecuritySize); + memcpy(Memfs->Security, Security, SecuritySize); + Memfs->SecuritySize = SecuritySize; Result = MemfsFileNodeMapCreate(CaseInsensitive, &Memfs->FileNodeMap); if (!NT_SUCCESS(Result)) @@ -1040,9 +981,8 @@ NTSTATUS MemfsCreateFunnel( VolumeParams.FlushAndPurgeOnCleanup = false; VolumeParams.AllowOpenInKernelMode = 1; if (0 != VolumePrefix) - wcscpy_s(VolumeParams.Prefix, sizeof VolumeParams.Prefix / sizeof(WCHAR), VolumePrefix); - wcscpy_s(VolumeParams.FileSystemName, sizeof VolumeParams.FileSystemName / sizeof(WCHAR), - 0 != FileSystemName ? FileSystemName : L"-MEMFS"); + wcscpy_s(VolumeParams.Prefix, VolumePrefix); + wcscpy_s(VolumeParams.FileSystemName, FileSystemName); Result = FspFileSystemCreate(DevicePath, &VolumeParams, &MemfsInterface, &Memfs->FileSystem); if (!NT_SUCCESS(Result)) @@ -1053,8 +993,10 @@ NTSTATUS MemfsCreateFunnel( } Memfs->FileSystem->UserContext = Memfs; - Memfs->VolumeLabelLength = sizeof L"MEMFS" - sizeof(WCHAR); - memcpy(Memfs->VolumeLabel, L"MEMFS", Memfs->VolumeLabelLength); + + wcscpy_s(Memfs->VolumeLabel, VolumeLabel); + Memfs->VolumeTotal = VolumeTotal; + Memfs->VolumeFree = VolumeFree; /* * Create root directory. @@ -1063,7 +1005,6 @@ NTSTATUS MemfsCreateFunnel( Result = MemfsFileNodeCreate(Memfs->FileProvider, L"\\", &RootNode); if (!NT_SUCCESS(Result)) { - wprintf(L"FAILED TO MAKE FILE %08x\n", Result); MemfsDelete(Memfs); return Result; } diff --git a/filesystem/memfs.h b/filesystem/memfs.h new file mode 100644 index 0000000..c0d550f --- /dev/null +++ b/filesystem/memfs.h @@ -0,0 +1,53 @@ +#pragma once + +#include + +#include + +#define MEMFS_MAX_PATH 512 +FSP_FSCTL_STATIC_ASSERT(MEMFS_MAX_PATH > MAX_PATH, + "MEMFS_MAX_PATH must be greater than MAX_PATH."); + +#define MEMFS_SECTOR_SIZE 512 +#define MEMFS_SECTORS_PER_ALLOCATION_UNIT 1 + +typedef struct _MEMFS MEMFS; + +typedef struct _MEMFS_FILE_PROVIDER MEMFS_FILE_PROVIDER; + +enum +{ + MemfsDisk = 0x00000000, + MemfsNet = 0x00000001, + MemfsDeviceMask = 0x0000000f, + MemfsCaseInsensitive = 0x80000000, +}; + +NTSTATUS MemfsCreateFunnel( + ULONG Flags, + ULONG FileInfoTimeout, + ULONG MaxFileNodes, + PWSTR FileSystemName, + PWSTR VolumePrefix, + PWSTR VolumeLabel, + UINT64 VolumeTotal, + UINT64 VolumeFree, + PVOID Security, + SIZE_T SecuritySize, + MEMFS_FILE_PROVIDER* FileProvider, + MEMFS** PMemfs); +VOID MemfsDelete(MEMFS* Memfs); +NTSTATUS MemfsStart(MEMFS* Memfs); +VOID MemfsStop(MEMFS* Memfs); +FSP_FILE_SYSTEM* MemfsFileSystem(MEMFS* Memfs); + +MEMFS_FILE_PROVIDER* CreateProvider( + std::function Open, + std::function Close, + std::function Read); +void CloseProvider(MEMFS_FILE_PROVIDER* Provider); + +NTSTATUS CreateFsFile(MEMFS* Memfs, + PWSTR FileName, BOOLEAN Directory); + + diff --git a/gui/cApp.cpp b/gui/cApp.cpp new file mode 100644 index 0000000..eced918 --- /dev/null +++ b/gui/cApp.cpp @@ -0,0 +1,60 @@ +#include "cApp.h" + +#include "../winfspcheck.h" + +#include + +#define MESSAGE_ERROR(format, ...) wxMessageBox(wxString::Format(format, __VA_ARGS__), "Error", wxICON_ERROR | wxOK | wxCENTRE) + +cApp::cApp() { + +} + +cApp::~cApp() { + +} + +bool cApp::OnInit() { + auto result = LoadWinFsp(); + if (result != WinFspCheckResult::LOADED) { + switch (result) + { + case WinFspCheckResult::CANNOT_ENUMERATE: + MESSAGE_ERROR("Could not iterate over drivers to get WinFsp install. System-specific error: %d", GetLastError()); + break; + case WinFspCheckResult::NOT_FOUND: + MESSAGE_ERROR("Could not find WinFsp as an installed driver. Maybe you don't have it installed?"); + break; + case WinFspCheckResult::NO_DLL: + MESSAGE_ERROR("Could not find WinFsp's DLL in the driver's folder. Try reinstalling WinFsp."); + break; + case WinFspCheckResult::CANNOT_LOAD: + MESSAGE_ERROR("Could not load WinFsp's DLL in the driver's folder. Try reinstalling WinFsp."); + break; + default: + MESSAGE_ERROR("An unknown error occurred when trying to load WinFsp's DLL: %d", result); + break; + } + return false; + } + + fs::path DataFolder; + { + PWSTR appDataFolder; + if (SHGetKnownFolderPath(FOLDERID_RoamingAppData, 0, NULL, &appDataFolder) != S_OK) { + MESSAGE_ERROR("Could not get the location of your AppData folder.", result); + return false; + } + DataFolder = appDataFolder; + CoTaskMemFree(appDataFolder); + } + DataFolder /= "EGL2"; + if (!fs::create_directories(DataFolder) && !fs::is_directory(DataFolder)) { + MESSAGE_ERROR("Could not create EGL2 folder.", result); + return false; + } + + m_frame1 = new cMain(DataFolder / "config", DataFolder); + m_frame1->Show(); + return true; +} \ No newline at end of file diff --git a/gui/cApp.h b/gui/cApp.h new file mode 100644 index 0000000..fae44e0 --- /dev/null +++ b/gui/cApp.h @@ -0,0 +1,19 @@ +#pragma once + +#include + +#include "cMain.h" + +class cApp : public wxApp +{ +public: + cApp(); + ~cApp(); + +private: + cMain* m_frame1 = nullptr; + +public: + virtual bool OnInit(); +}; + diff --git a/gui/cAuth.cpp b/gui/cAuth.cpp new file mode 100644 index 0000000..7669fd2 --- /dev/null +++ b/gui/cAuth.cpp @@ -0,0 +1,78 @@ +#include "cAuth.h" + +#define EXCHANGE_INFO "This exchange code is used to log you into Fortnite. Without it, the " \ + "game doesn't know you. To get this code, press the get code button " \ + "above. Copy the giant blob of numbers and letters, not including the " \ + "quotes. Paste it into the text box above and click launch.\n "\ + "Example exchange code: d9c230c0e0354a619249ba1156df5e63" + +#define EXCHANGE_SEC "This code expires 5 minutes after visiting the URL and can only be " \ + "used once. Anyone with this code can impersonate you and take control " \ + "of your account. When using EGL2, this code is passed directly to the " \ + "game and is not read, used, stored, or modified for any purpose other " \ + "than to launch the game." + +#define EXCHANGE_URL "https://www.epicgames.com/id/login?redirectUrl=https%3A%2F%2Fwww.epicgames.com%2Fid%2Fapi%2Fexchange" + +#include + +cAuth::cAuth(cMain* main) : wxModalWindow(main, wxID_ANY, "Launch Game - EGL2", wxDefaultPosition, wxDefaultSize, wxDEFAULT_FRAME_STYLE ^ (wxMAXIMIZE_BOX | wxRESIZE_BORDER)) { + this->SetIcon(wxICON(APP_ICON)); + this->SetMinSize(wxSize(400, 330)); + this->SetMaxSize(wxSize(400, 330)); + + panel = new wxPanel(this, wxID_ANY, + wxDefaultPosition, + wxDefaultSize, + wxTAB_TRAVERSAL); + + codeLabel = new wxStaticText(panel, wxID_ANY, "Enter your exchange code", wxDefaultPosition, wxDefaultSize, wxALIGN_CENTRE_HORIZONTAL); + codeInput = new wxTextCtrl(panel, wxID_ANY); + codeLink = new wxButton(panel, wxID_ANY, "Get Code"); + codeSubmit = new wxButton(panel, wxID_ANY, "Launch"); + codeInfo = new wxStaticText(panel, wxID_ANY, EXCHANGE_INFO); + codeSecBox = new wxStaticBoxSizer(wxVERTICAL, panel, "Security Info"); + codeSecInfo = new wxStaticText(panel, wxID_ANY, EXCHANGE_SEC); + codeSecBox->Add(codeSecInfo, 1, wxEXPAND); + + auto sizerBtns = new wxBoxSizer(wxHORIZONTAL); + + sizerBtns->Add(codeLink, 1, wxEXPAND | wxRIGHT, 3); + sizerBtns->Add(codeSubmit, 1, wxEXPAND | wxLEFT, 3); + + auto sizer = new wxBoxSizer(wxVERTICAL); + + sizer->Add(codeLabel, 0, wxEXPAND); + sizer->Add(codeInput, 0, wxEXPAND | wxUP | wxDOWN, 5); + sizer->Add(sizerBtns, 0, wxEXPAND | wxDOWN, 5); + sizer->Add(codeInfo, 3, wxEXPAND); + sizer->Add(codeSecBox, 4, wxEXPAND); + + auto topSizer = new wxBoxSizer(wxVERTICAL); + topSizer->Add(sizer, wxSizerFlags(1).Expand().Border(wxALL, 5)); + panel->SetSizerAndFit(topSizer); + this->SetSize(wxSize(400, 330)); + + codeLink->Bind(wxEVT_BUTTON, &cAuth::OnGetCodeClicked, this); + codeSubmit->Bind(wxEVT_BUTTON, &cAuth::OnSubmitClicked, this); +} + +cAuth::~cAuth() { + +} + +wxString& cAuth::GetCode() { + return returnValue; +} + +void cAuth::OnGetCodeClicked(wxCommandEvent& evt) { + wxLaunchDefaultBrowser(EXCHANGE_URL); +} + +void cAuth::OnSubmitClicked(wxCommandEvent& evt) { + if (codeInput->GetValue().IsEmpty()) { + return; + } + returnValue = codeInput->GetValue(); + Close(); +} \ No newline at end of file diff --git a/gui/cAuth.h b/gui/cAuth.h new file mode 100644 index 0000000..de9ed02 --- /dev/null +++ b/gui/cAuth.h @@ -0,0 +1,35 @@ +#pragma once + +#include "cMain.h" + +#include "wxModalWindow.h" + +#include +#include + + +class cAuth : public wxModalWindow +{ +public: + cAuth(cMain* main); + ~cAuth(); + + wxString& GetCode(); + +protected: + wxPanel* panel = nullptr; + + wxStaticText* codeLabel = nullptr; + wxTextCtrl* codeInput = nullptr; + wxButton* codeLink = nullptr; + wxButton* codeSubmit = nullptr; + wxStaticText* codeInfo = nullptr; + wxStaticBoxSizer* codeSecBox = nullptr; + wxStaticText* codeSecInfo = nullptr; + + wxString returnValue; + + void OnGetCodeClicked(wxCommandEvent& evt); + void OnSubmitClicked(wxCommandEvent& evt); +}; + diff --git a/gui/cMain.cpp b/gui/cMain.cpp new file mode 100644 index 0000000..fae71fa --- /dev/null +++ b/gui/cMain.cpp @@ -0,0 +1,248 @@ +#include "cMain.h" +#include "cSetup.h" +#include "cProgress.h" +#include "cAuth.h" + +#include +#include +#include +#include + +#define DESC_TEXT_DEFAULT "Hover over a button to see what it does" +#define DESC_TEXT_SETUP "Click this before running for the first time. Has options for setting up your installation." +#define DESC_TEXT_VERIFY "If you believe some chunks are invalid, click this to verify all chunks that are already downloaded. Redownloads any invalid chunks." +#define DESC_TEXT_PURGE "Deletes any chunks that aren't used anymore. Useful for slimming down your install after you've updated your game." +#define DESC_TEXT_PRELOAD "Downloads any chunks that you don't already have downloaded for the most recent version, a.k.a updating. It's reccomended to update your install before playing in case you have a hotfix available." +#define DESC_TEXT_START "Mount your install to a drive letter, and, if selected, move the files necessary to a folder in order to start your game." +#define DESC_TEXT_PLAY "When clicked, it will prompt you to provide your exchange code. After giving it, the game will launch and you'll be good to go." \ + "\n\nThis option is only available when playing is enabled." + +#define STATUS_NEED_SETUP "Setup where you want your game to be installed." +#define STATUS_NORMAL "Click start to mount." +#define STATUS_PLAYABLE "Started! Press \"Play\" to start playing!" +#define STATUS_UNPLAYABLE "Started! If you want to play, enable it in your setup!" + +#define LAUNCH_GAME_ARGS "-AUTH_LOGIN=unused AUTH_TYPE=exchangecode -epicapp=Fortnite -epicenv=Prod -epicportal -epiclocale=en-us -AUTH_PASSWORD=%s" + +#define BIND_BUTTON_DESC(btn, desc) \ + btn->Bind(wxEVT_MOTION, std::bind(&cMain::OnButtonHover, this, desc)); \ + btn->Bind(wxEVT_LEAVE_WINDOW, std::bind(&cMain::OnButtonHover, this, DESC_TEXT_DEFAULT)); + +cMain::cMain(fs::path settingsPath, fs::path manifestPath) : wxFrame(nullptr, wxID_ANY, "EGL2", wxDefaultPosition, wxDefaultSize, wxDEFAULT_FRAME_STYLE ^ (wxMAXIMIZE_BOX | wxRESIZE_BORDER)) { + this->SetIcon(wxICON(APP_ICON)); + this->SetMinSize(wxSize(450, 250)); + this->SetMaxSize(wxSize(450, 250)); + + panel = new wxPanel(this, wxID_ANY, + wxDefaultPosition, + wxDefaultSize, + wxTAB_TRAVERSAL); + + auto grid = new wxGridBagSizer(); + + setupBtn = new wxButton(panel, wxID_ANY, "Setup"); + verifyBtn = new wxButton(panel, wxID_ANY, "Verify"); + purgeBtn = new wxButton(panel, wxID_ANY, "Purge"); + preloadBtn = new wxButton(panel, wxID_ANY, "Update"); + startBtn = new wxButton(panel, wxID_ANY, "Start"); + startFnBtn = new wxButton(panel, wxID_ANY, "Play"); + statusBar = new wxStaticText(panel, wxID_ANY, wxEmptyString); + selloutBar = new wxStaticText(panel, wxID_ANY, "Use code \"furry\"! (#ad)"); + + setupBtn->Bind(wxEVT_BUTTON, std::bind(&cMain::OnSetupClicked, this)); + verifyBtn->Bind(wxEVT_BUTTON, std::bind(&cMain::OnVerifyClicked, this)); + purgeBtn->Bind(wxEVT_BUTTON, std::bind(&cMain::OnPurgeClicked, this)); + preloadBtn->Bind(wxEVT_BUTTON, std::bind(&cMain::OnPreloadClicked, this)); + startBtn->Bind(wxEVT_BUTTON, std::bind(&cMain::OnStartClicked, this)); + startFnBtn->Bind(wxEVT_BUTTON, std::bind(&cMain::OnPlayClicked, this)); + startFnBtn->Disable(); + + BIND_BUTTON_DESC(setupBtn, DESC_TEXT_SETUP); + BIND_BUTTON_DESC(verifyBtn, DESC_TEXT_VERIFY); + BIND_BUTTON_DESC(purgeBtn, DESC_TEXT_PURGE); + BIND_BUTTON_DESC(preloadBtn, DESC_TEXT_PRELOAD); + BIND_BUTTON_DESC(startBtn, DESC_TEXT_START); + BIND_BUTTON_DESC(startFnBtn, DESC_TEXT_PLAY); + + descBox = new wxStaticBoxSizer(wxVERTICAL, panel, "Description"); + descTxt = new wxStaticText(panel, wxID_ANY, DESC_TEXT_DEFAULT); + descBox->Add(descTxt, 1, wxEXPAND); + + auto barSizer = new wxBoxSizer(wxHORIZONTAL); + barSizer->Add(statusBar); + barSizer->AddStretchSpacer(); + barSizer->Add(selloutBar); + + grid->Add(setupBtn, wxGBPosition(0, 0), wxGBSpan(1, 2), wxEXPAND); + grid->Add(verifyBtn, wxGBPosition(1, 0), wxGBSpan(1, 2), wxEXPAND); + grid->Add(purgeBtn, wxGBPosition(2, 0), wxGBSpan(1, 2), wxEXPAND); + grid->Add(preloadBtn, wxGBPosition(3, 0), wxGBSpan(1, 2), wxEXPAND); + grid->Add(startBtn, wxGBPosition(4, 0), wxGBSpan(1, 1), wxEXPAND); + grid->Add(startFnBtn, wxGBPosition(4, 1), wxGBSpan(1, 1), wxEXPAND); + grid->Add(descBox, wxGBPosition(0, 2), wxGBSpan(5, 1), wxEXPAND); + grid->Add(barSizer, wxGBPosition(5, 0), wxGBSpan(1, 3), wxEXPAND); + + grid->AddGrowableCol(0, 1); + grid->AddGrowableCol(1, 1); + grid->AddGrowableCol(2, 4); + + grid->AddGrowableRow(0); + grid->AddGrowableRow(1); + grid->AddGrowableRow(2); + grid->AddGrowableRow(3); + grid->AddGrowableRow(4); + + auto topSizer = new wxBoxSizer(wxVERTICAL); + topSizer->Add(grid, wxSizerFlags(1).Expand().Border(wxALL, 5)); + panel->SetSizerAndFit(topSizer); + this->SetSize(wxSize(450, 250)); + + Bind(wxEVT_CLOSE_WINDOW, &cMain::OnClose, this); + + SettingsPath = settingsPath; + memset(&Settings, 0, sizeof(Settings)); + auto settingsFp = fopen(SettingsPath.string().c_str(), "rb"); + if (settingsFp) { + SettingsRead(&Settings, settingsFp); + fclose(settingsFp); + } + else { + Settings.MountDrive = '\0'; + Settings.CompressionLevel = 4; // Slowest + Settings.CompressionMethod = 1; // Decompress + Settings.EnableGaming = true; + Settings.VerifyCache = true; + } + + { + auto valid = SettingsValidate(&Settings); + verifyBtn->Enable(valid); + purgeBtn->Enable(valid); + preloadBtn->Enable(valid); + startBtn->Enable(valid); + SetStatus(valid ? STATUS_NORMAL : STATUS_NEED_SETUP); + } + + ManifestAuthGrab(&ManifestAuth); + ManifestAuthGetManifest(ManifestAuth, manifestPath, &Manifest); +} + +cMain::~cMain() { + auto settingsFp = fopen(SettingsPath.string().c_str(), "wb"); + SettingsWrite(&Settings, settingsFp); + fclose(settingsFp); +} + +void cMain::OnButtonHover(const char* string) { + if (strcmp(descTxt->GetLabel().c_str(), string)) { + descTxt->SetLabel(string); + descBox->Fit(descTxt); + descBox->FitInside(descTxt); + descBox->Layout(); + } +} + +void cMain::OnSetupClicked() { + cSetup(this, &Settings).ShowModal(); + auto valid = SettingsValidate(&Settings); + verifyBtn->Enable(valid); + purgeBtn->Enable(valid); + preloadBtn->Enable(valid); + startBtn->Enable(valid); + SetStatus(valid ? STATUS_NORMAL : STATUS_NEED_SETUP); + if (valid) { + auto& build = GetMountedBuild(); + build->SetupCacheDirectory(); + } +} + +#define RUN_PROGRESS(taskName, funcName, ...) \ +{ \ + auto cancelled = new cancel_flag(); \ + cProgress* progress = new cProgress(this, taskName, \ + *cancelled); \ + progress->Show(true); \ + \ + wxWindowPtr progressPtr(progress); \ + std::thread([=]() { \ + auto& b = GetMountedBuild(); \ + \ + b->##funcName( \ + [=](uint32_t m) { progressPtr->SetMaximum(m); }, \ + [=]() { progressPtr->Increment(); }, \ + *cancelled, __VA_ARGS__); \ + \ + progressPtr->Finish(); \ + progressPtr->Close(); \ + delete cancelled; \ + }).detach(); \ +} + +void cMain::OnVerifyClicked() { + RUN_PROGRESS("Verifying", VerifyAllChunks, 64); +} + +void cMain::OnPurgeClicked() { + RUN_PROGRESS("Purging", PurgeUnusedChunks); +} + +void cMain::OnPreloadClicked() { + RUN_PROGRESS("Updating", PreloadAllChunks, 64); +} + +void cMain::OnStartClicked() { + auto& build = GetMountedBuild(); + if (build->Mounted()) { + build->Unmount(); + setupBtn->Enable(); + startFnBtn->Disable(); + startBtn->SetLabel("Start"); + SetStatus(STATUS_NORMAL); + } + else { + if (build->Mount()) { + if (Settings.EnableGaming) { + RUN_PROGRESS("Setting Up", SetupGameDirectory, 64, Settings.GameDir); + } + + setupBtn->Disable(); + startFnBtn->Enable(Settings.EnableGaming); + startBtn->SetLabel("Stop"); + SetStatus(Settings.EnableGaming ? STATUS_PLAYABLE : STATUS_UNPLAYABLE); + } + } +} + +void cMain::OnPlayClicked() { + cAuth auth(this); + auth.ShowModal(); + if (!auth.GetCode().IsEmpty()) { + auto& build = GetMountedBuild(); + build->LaunchGame(Settings.GameDir, wxString::Format(LAUNCH_GAME_ARGS, auth.GetCode()).c_str()); + } +} + +void cMain::OnClose(wxCloseEvent& evt) +{ + if (evt.CanVeto() && Build && Build->Mounted()) { + if (wxMessageBox("Your game is currently mounted. Do you want to exit now?", "Currently Mounted - EGL2", wxICON_QUESTION | wxYES_NO) != wxYES) + { + evt.Veto(); + return; + } + } + + Destroy(); +} + +void cMain::SetStatus(const char* string) { + statusBar->SetLabel(string); +} + +std::unique_ptr& cMain::GetMountedBuild() { + if (!Build) { + Build = std::make_unique(Manifest, std::string(1, Settings.MountDrive) + ':', Settings.CacheDir, [](const char* error) {}); + Build->StartStorage(SettingsGetStorageFlags(&Settings)); + } + return Build; +} \ No newline at end of file diff --git a/gui/cMain.h b/gui/cMain.h new file mode 100644 index 0000000..d71b103 --- /dev/null +++ b/gui/cMain.h @@ -0,0 +1,53 @@ +#pragma once + +#include "../MountedBuild.h" +#include "settings.h" +#include + +class cMain : public wxFrame +{ +public: + cMain(fs::path settingsPath, fs::path manifestPath); + ~cMain(); + +protected: + wxPanel* panel = nullptr; + + wxButton* setupBtn = nullptr; + wxButton* verifyBtn = nullptr; + wxButton* purgeBtn = nullptr; + wxButton* preloadBtn = nullptr; + wxButton* startBtn = nullptr; + wxButton* startFnBtn = nullptr; + + wxStaticBoxSizer* descBox = nullptr; + wxStaticText* descTxt = nullptr; + + wxStaticText* statusBar = nullptr; + wxStaticText* selloutBar = nullptr; + + void OnButtonHover(const char* string); + + void OnSetupClicked(); + void OnVerifyClicked(); + void OnPurgeClicked(); + void OnPreloadClicked(); + void OnStartClicked(); + void OnPlayClicked(); + + void OnClose(wxCloseEvent& evt); + + void SetStatus(const char* string); + +private: + std::unique_ptr& GetMountedBuild(); + + fs::path SettingsPath; + SETTINGS Settings; + + std::unique_ptr Build; + + MANIFEST_AUTH* ManifestAuth; + MANIFEST* Manifest; +}; + diff --git a/gui/cProgress.cpp b/gui/cProgress.cpp new file mode 100644 index 0000000..231ba42 --- /dev/null +++ b/gui/cProgress.cpp @@ -0,0 +1,122 @@ +#include "cProgress.h" +#include + +namespace ch = std::chrono; + +cProgress::cProgress(cMain* main, wxString taskName, cancel_flag& cancelFlag, float updateFreq, uint32_t maximum) : wxFrame(main, wxID_ANY, taskName + " - EGL2", wxDefaultPosition, wxDefaultSize, wxDEFAULT_FRAME_STYLE ^ (wxMAXIMIZE_BOX | wxRESIZE_BORDER)) { + value = 0; + frequency = updateFreq; + maxValue = maximum; + startTime = Clock::now(); + + this->SetIcon(wxICON(APP_ICON)); + this->SetMinSize(wxSize(400, -1)); + this->SetMaxSize(wxSize(400, -1)); + + panel = new wxPanel(this, wxID_ANY, + wxDefaultPosition, + wxDefaultSize, + wxTAB_TRAVERSAL); + + progressBar = new wxGauge(panel, wxID_ANY, maxValue, wxDefaultPosition, wxDefaultSize, wxGA_HORIZONTAL | wxGA_SMOOTH); + + progressPercent = new wxStaticText(panel, wxID_ANY, wxEmptyString, wxDefaultPosition, wxDefaultSize, wxALIGN_CENTRE_HORIZONTAL | wxST_NO_AUTORESIZE); + progressTotal = new wxStaticText(panel, wxID_ANY, wxEmptyString, wxDefaultPosition, wxDefaultSize, wxALIGN_CENTRE_HORIZONTAL | wxST_NO_AUTORESIZE); + progressTimeElapsed = new wxStaticText(panel, wxID_ANY, wxEmptyString, wxDefaultPosition, wxDefaultSize, wxALIGN_CENTRE_HORIZONTAL | wxST_NO_AUTORESIZE); + progressTimeETA = new wxStaticText(panel, wxID_ANY, wxEmptyString, wxDefaultPosition, wxDefaultSize, wxALIGN_CENTRE_HORIZONTAL | wxST_NO_AUTORESIZE); + + progressCancelBtn = new wxButton(panel, wxID_ANY, "Cancel"); + progressTaskbar = new wxAppProgressIndicator(this, maxValue); + progressDisabler = new wxWindowDisabler(this); + + progressTextSizer = new wxGridSizer(2, 2, 5, 5); + progressTextSizer->Add(progressPercent, 1, wxEXPAND); + progressTextSizer->Add(progressTimeElapsed, 1, wxEXPAND); + progressTextSizer->Add(progressTotal, 1, wxEXPAND); + progressTextSizer->Add(progressTimeETA, 1, wxEXPAND); + + auto sizer = new wxBoxSizer(wxVERTICAL); + + sizer->Add(progressBar, 0, wxEXPAND); + sizer->Add(progressTextSizer, 0, wxEXPAND | wxUP | wxDOWN, 5); + sizer->Add(progressCancelBtn, 0, wxEXPAND); + + auto topSizer = new wxBoxSizer(wxVERTICAL); + topSizer->Add(sizer, wxSizerFlags(1).Expand().Border(wxALL, 5)); + panel->SetSizerAndFit(topSizer); + this->Fit(); + + progressCancelBtn->Bind(wxEVT_BUTTON, [this, &cancelFlag](wxCommandEvent& evt) { Cancel(cancelFlag); }); + Bind(wxEVT_CLOSE_WINDOW, [this, &cancelFlag](wxCloseEvent& evt) { Cancel(cancelFlag); }); +} + +cProgress::~cProgress() { + delete progressTaskbar; + delete progressDisabler; +} + +inline void cProgress::Cancel(cancel_flag& cancelFlag) { + cancelFlag.cancel(); + progressCancelBtn->Disable(); + progressCancelBtn->SetLabel("Cancelling"); +} + +inline wxString FormatTime(ch::seconds secs) { + auto mins = ch::duration_cast(secs); + secs -= ch::duration_cast(mins); + auto hour = ch::duration_cast(mins); + mins -= ch::duration_cast(hour); + + return wxString::Format("%02d:%02d:%02d", hour.count(), mins.count(), int(secs.count())); +} + +inline void cProgress::Update(bool force) { + auto now = Clock::now(); + ch::duration duration = now - lastUpdate; + if (duration.count() < frequency || force) { + return; + } + lastUpdate = now; + progressBar->SetValue(value + 1); + progressBar->SetValue(value); + progressTaskbar->SetValue(value); + + { + progressPercent->SetLabel(wxString::Format("%.2f%%", float(value) * 100 / maxValue)); + progressTotal->SetLabel(wxString::Format("%u / %u", value.load(), maxValue)); + + auto elapsed = ch::duration_cast(now - startTime); + progressTimeElapsed->SetLabel("Elapsed: " + FormatTime(elapsed)); + + auto etaDivisor = float(maxValue - value) / value; + auto eta = etaDivisor ? ch::duration_cast(elapsed * etaDivisor) : ch::seconds::zero(); + progressTimeETA->SetLabel("ETA: " + FormatTime(eta)); + } +} + +void cProgress::SetFrequency(float updateFreq) { + frequency = updateFreq; +} + +void cProgress::SetMaximum(uint32_t maximum) { + if (maximum != maxValue) { + maxValue = maximum; + progressBar->SetRange(maxValue); + progressTaskbar->SetRange(maxValue); + Update(); + } +} + +void cProgress::Increment() { + if (finished) { + return; + } + value++; + Update(value == maxValue); +} + +void cProgress::Finish() { + value = maxValue; + finished = true; + Update(true); +} \ No newline at end of file diff --git a/gui/cProgress.h b/gui/cProgress.h new file mode 100644 index 0000000..be44023 --- /dev/null +++ b/gui/cProgress.h @@ -0,0 +1,47 @@ +#pragma once + +#include "cMain.h" +#include "../containers/cancel_flag.h" + +#include + +class cProgress : public wxFrame +{ + typedef std::chrono::steady_clock Clock; + +public: + cProgress(cMain* main, wxString taskName, cancel_flag& cancelFlag, float updateFreq = .05f, uint32_t maximum = 1); + ~cProgress(); + + void SetFrequency(float updateFreq); + void SetMaximum(uint32_t maximum); + void Increment(); + void Finish(); + +protected: + wxPanel* panel = nullptr; + + wxGauge* progressBar = nullptr; + wxSizer* progressTextSizer = nullptr; + wxStaticText* progressPercent = nullptr; + wxStaticText* progressTotal = nullptr; + wxStaticText* progressTimeElapsed = nullptr; + wxStaticText* progressTimeETA = nullptr; + wxButton* progressCancelBtn = nullptr; + + wxAppProgressIndicator* progressTaskbar = nullptr; + wxWindowDisabler* progressDisabler = nullptr; + +private: + void Cancel(cancel_flag& cancelFlag); + + void Update(bool force = false); + + bool finished = false; + Clock::time_point startTime; + Clock::time_point lastUpdate; + float frequency; + std::atomic_uint32_t value; + uint32_t maxValue; +}; + diff --git a/gui/cSetup.cpp b/gui/cSetup.cpp new file mode 100644 index 0000000..da6460f --- /dev/null +++ b/gui/cSetup.cpp @@ -0,0 +1,160 @@ +#include "cSetup.h" + +#include + +#include "settings.h" + +static const char* compMethods[] = { + "Keep Compressed as Downloaded (zLib)", + "Decompress", + "Use LZ4", + "Use zLib" +}; + +static const char* compLevels[] = { + "Fastest", + "Fast", + "Normal", + "Slow", + "Slowest" +}; + +#define VERIFY_TOOLTIP "Verify data that is read from the cache and redownload\n" \ + "it if the data is invalid.\n" \ + "Note: You may take a small performance hit." + +#define GAME_TOOLTIP "In order to launch the game, a workaround must be done\n" \ + "where all binaries are copied to a physical drive in order to\n" \ + "prevent the anticheat from getting grumpy.\n" \ + "Note: Depending on the install, an additional 300-400MB\n" \ + "of data will need to be allocated on your hard drive." + +#define DRIVE_ALPHABET "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + +cSetup::cSetup(cMain* main, SETTINGS* settings) : wxModalWindow(main, wxID_ANY, "Setup - EGL2", wxDefaultPosition, wxDefaultSize, wxDEFAULT_FRAME_STYLE ^ (wxMAXIMIZE_BOX | wxRESIZE_BORDER)) { + Settings = settings; + + this->SetIcon(wxICON(APP_ICON)); + this->SetMinSize(wxSize(500, -1)); + this->SetMaxSize(wxSize(500, -1)); + + wxPanel* panel = new wxPanel(this, wxID_ANY, + wxDefaultPosition, + wxDefaultSize, + wxTAB_TRAVERSAL); + + settingsContainer = new wxWindow(panel, wxID_ANY); + checkboxContainer = new wxWindow(panel, wxID_ANY); + + auto settingsGrid = new wxGridBagSizer(2, 2); + + cacheDirTxt = new wxStaticText(settingsContainer, wxID_ANY, "Cache Folder"); + cacheDirValue = new wxDirPickerCtrl(settingsContainer, wxID_ANY); + + wxArrayString drives; + auto usedDrives = GetLogicalDrives(); + for (int b = 0; b < strlen(DRIVE_ALPHABET); ++b) { + if (!((usedDrives >> b) & 1)) { + drives.push_back(wxString(DRIVE_ALPHABET[b]) + ':'); + } + } + + mountDirTxt = new wxStaticText(settingsContainer, wxID_ANY, "Mount Drive"); + mountDirValue = new wxChoice(settingsContainer, wxID_ANY, wxDefaultPosition, wxDefaultSize, drives); + + gameDirTxt = new wxStaticText(settingsContainer, wxID_ANY, "Game Folder"); + gameDirValue = new wxDirPickerCtrl(settingsContainer, wxID_ANY); + + compMethodTxt = new wxStaticText(settingsContainer, wxID_ANY, "Compression Method"); + compMethodValue = new wxChoice(settingsContainer, wxID_ANY); + compMethodValue->Append(wxArrayString(_countof(compMethods), compMethods)); + + compLevelTxt = new wxStaticText(settingsContainer, wxID_ANY, "Compression Level"); + compLevelValue = new wxChoice(settingsContainer, wxID_ANY); + compLevelValue->Append(wxArrayString(_countof(compLevels), compLevels)); + + + verifyCacheCheckbox = new wxCheckBox(checkboxContainer, wxID_ANY, "Verify cache when read from"); + verifyCacheCheckbox->SetToolTip(VERIFY_TOOLTIP); + + gameDirCheckbox = new wxCheckBox(checkboxContainer, wxID_ANY, "Enable playing of game"); + gameDirCheckbox->SetToolTip(GAME_TOOLTIP); + + settingsGrid->Add(cacheDirTxt, wxGBPosition(0, 0), wxGBSpan(1, 1), wxEXPAND); + settingsGrid->Add(cacheDirValue, wxGBPosition(0, 2), wxGBSpan(1, 1), wxEXPAND); + + settingsGrid->Add(mountDirTxt, wxGBPosition(1, 0), wxGBSpan(1, 1), wxEXPAND); + settingsGrid->Add(mountDirValue, wxGBPosition(1, 2), wxGBSpan(1, 1), wxEXPAND); + + settingsGrid->Add(gameDirTxt, wxGBPosition(2, 0), wxGBSpan(1, 1), wxEXPAND); + settingsGrid->Add(gameDirValue, wxGBPosition(2, 2), wxGBSpan(1, 1), wxEXPAND); + + settingsGrid->Add(compMethodTxt, wxGBPosition(3, 0), wxGBSpan(1, 1), wxEXPAND); + settingsGrid->Add(compLevelTxt, wxGBPosition(4, 0), wxGBSpan(1, 1), wxEXPAND); + settingsGrid->Add(compMethodValue, wxGBPosition(3, 2), wxGBSpan(1, 1), wxEXPAND); + settingsGrid->Add(compLevelValue, wxGBPosition(4, 2), wxGBSpan(1, 1), wxEXPAND); + + settingsGrid->AddGrowableCol(2); + settingsGrid->Add(5, 1, wxGBPosition(0, 1)); + + settingsContainer->SetSizerAndFit(settingsGrid); + + auto checkboxSizer = new wxBoxSizer(wxHORIZONTAL); + + checkboxSizer->Add(verifyCacheCheckbox, 1); + checkboxSizer->Add(gameDirCheckbox, 1); + + checkboxContainer->SetSizerAndFit(checkboxSizer); + + auto sizer = new wxBoxSizer(wxVERTICAL); + + sizer->Add(settingsContainer, wxSizerFlags().Expand().Border(wxUP | wxRIGHT | wxLEFT, 10)); + sizer->AddSpacer(10); + sizer->Add(checkboxContainer, wxSizerFlags().Expand().Border(wxDOWN | wxRIGHT | wxLEFT, 10)); + + panel->SetSizerAndFit(sizer); + this->Fit(); + + wxToolTip::SetAutoPop(15000); + + ReadConfig(); + if (Settings->MountDrive == '\0') { + mountDirValue->SetSelection(0); + } + // disable when first 2 options are selected + compLevelValue->Enable(compMethodValue->GetSelection() >= 2); + compMethodValue->Bind(wxEVT_CHOICE, [this](wxCommandEvent& evt) { + compLevelValue->Enable(compMethodValue->GetSelection() >= 2); + }); + Bind(wxEVT_CLOSE_WINDOW, std::bind(&cSetup::WriteConfig, this)); +} + +cSetup::~cSetup() { + +} + +void cSetup::ReadConfig() { + cacheDirValue->SetPath(Settings->CacheDir); + mountDirValue->SetStringSelection(wxString(Settings->MountDrive) + ':'); + gameDirValue->SetPath(Settings->GameDir); + + compMethodValue->SetSelection(Settings->CompressionMethod); + compLevelValue->SetSelection(Settings->CompressionLevel); + + verifyCacheCheckbox->SetValue(Settings->VerifyCache); + gameDirCheckbox->SetValue(Settings->EnableGaming); +} + +void cSetup::WriteConfig() { + strcpy_s(Settings->CacheDir, cacheDirValue->GetPath().c_str()); + Settings->MountDrive = mountDirValue->GetStringSelection()[0]; + strcpy_s(Settings->GameDir, gameDirValue->GetPath().c_str()); + + Settings->CompressionMethod = compMethodValue->GetSelection(); + Settings->CompressionLevel = compLevelValue->GetSelection(); + + Settings->VerifyCache = verifyCacheCheckbox->IsChecked(); + Settings->EnableGaming = gameDirCheckbox->IsChecked(); + + this->Destroy(); +} \ No newline at end of file diff --git a/gui/cSetup.h b/gui/cSetup.h new file mode 100644 index 0000000..c14e1c9 --- /dev/null +++ b/gui/cSetup.h @@ -0,0 +1,46 @@ +#pragma once + +#include "cMain.h" +#include "wxModalWindow.h" + +#include +#include + +class cSetup : public wxModalWindow +{ +public: + cSetup(cMain* main, SETTINGS* settings); + ~cSetup(); + +protected: + wxPanel* topPanel = nullptr; + + wxWindow* settingsContainer = nullptr; + wxWindow* checkboxContainer = nullptr; + + wxStaticText* cacheDirTxt = nullptr; + wxDirPickerCtrl* cacheDirValue = nullptr; + + wxStaticText* mountDirTxt = nullptr; + wxChoice* mountDirValue = nullptr; + // TODO: add folder support with a radio button/combo box + + wxStaticText* gameDirTxt = nullptr; + wxDirPickerCtrl* gameDirValue = nullptr; + + wxStaticText* compMethodTxt = nullptr; + wxChoice* compMethodValue = nullptr; + + wxStaticText* compLevelTxt = nullptr; + wxChoice* compLevelValue = nullptr; + + wxCheckBox* verifyCacheCheckbox = nullptr; + wxCheckBox* gameDirCheckbox = nullptr; + +private: + SETTINGS* Settings; + + void ReadConfig(); + void WriteConfig(); +}; + diff --git a/gui/guimain.cpp b/gui/guimain.cpp new file mode 100644 index 0000000..e4e711f --- /dev/null +++ b/gui/guimain.cpp @@ -0,0 +1,3 @@ +#include "cApp.h" + +wxIMPLEMENT_APP(cApp); \ No newline at end of file diff --git a/gui/resources.rc b/gui/resources.rc new file mode 100644 index 0000000..1d0e318 --- /dev/null +++ b/gui/resources.rc @@ -0,0 +1,27 @@ +APP_ICON ICON "../icon.ico" + +#include + +VS_VERSION_INFO VERSIONINFO +FILEVERSION 1, 0, 0, 0 +PRODUCTVERSION 1, 0, 0, 0 +{ + BLOCK "StringFileInfo" + { + BLOCK "040904b0" + { + VALUE "CompanyName", "WorkingRobot" + VALUE "FileDescription", "EGL2" + VALUE "FileVersion", "1.0.0.0" + VALUE "InternalName", "egl2" + VALUE "LegalCopyright", "EGL2 (c) by Aleks Margarian (WorkingRobot)" + VALUE "OriginalFilename", "EGL2.exe" + VALUE "ProductName", "EGL2" + VALUE "ProductVersion", "1.0.0.0" + } + } + BLOCK "VarFileInfo" + { + VALUE "Translation", 0x409, 1200 + } +} \ No newline at end of file diff --git a/gui/settings.cpp b/gui/settings.cpp new file mode 100644 index 0000000..4708f7e --- /dev/null +++ b/gui/settings.cpp @@ -0,0 +1,125 @@ +#include "settings.h" + +#include + +template +inline T ReadValue(FILE* File) { + char val[sizeof(T)]; + fread(val, sizeof(T), 1, File); + return *(T*)val; +} + +// ptrLength does not include the \0 at the end (ptrLength of 0 still includes \0) +inline bool ReadString(char* ptr, uint16_t ptrLength, FILE* File) { + auto stringSize = ntohs(ReadValue(File)); + if (stringSize > ptrLength) { + return false; + } + fread(ptr, 1, stringSize, File); + ptr[stringSize] = '\0'; + return true; +} + +bool SettingsRead(SETTINGS* Settings, FILE* File) { + rewind(File); + + if (ntohl(ReadValue(File)) != FILE_CONFIG_MAGIC) { + return false; + } + if (ntohs(ReadValue(File)) != FILE_CONFIG_VERSION) { + return false; + } + + ReadString(Settings->CacheDir, _MAX_PATH, File); + ReadString(Settings->GameDir, _MAX_PATH, File); + + Settings->MountDrive = ReadValue(File); + + Settings->CompressionMethod = ReadValue(File); + Settings->CompressionLevel = ReadValue(File); + + Settings->VerifyCache = ReadValue(File); + Settings->EnableGaming = ReadValue(File); + + return true; +} + +template +inline void WriteValue(T val, FILE* File) { + fwrite(&val, sizeof(T), 1, File); +} + +inline void WriteString(char* ptr, int ptrLength, FILE* File) { + WriteValue(htons(ptrLength), File); + fwrite(ptr, 1, ptrLength, File); +} + +void SettingsWrite(SETTINGS* Settings, FILE* File) +{ + rewind(File); + + WriteValue(htonl(FILE_CONFIG_MAGIC), File); + WriteValue(htons(FILE_CONFIG_VERSION), File); + + WriteString(Settings->CacheDir, strlen(Settings->CacheDir), File); + WriteString(Settings->GameDir, strlen(Settings->GameDir), File); + + WriteValue(Settings->MountDrive, File); + + WriteValue(Settings->CompressionMethod, File); + WriteValue(Settings->CompressionLevel, File); + + WriteValue(Settings->VerifyCache, File); + WriteValue(Settings->EnableGaming, File); +} + +bool SettingsValidate(SETTINGS* Settings) { + if (!fs::is_directory(Settings->CacheDir)) { + return false; + } + if (Settings->EnableGaming && !fs::is_directory(Settings->GameDir)) { + return false; + } + return true; +} + +uint32_t SettingsGetStorageFlags(SETTINGS* Settings) { + uint32_t StorageFlags = 0; + if (Settings->VerifyCache) { + StorageFlags |= StorageVerifyHashes; + } + switch (Settings->CompressionMethod) + { + case 0: + StorageFlags |= StorageCompressed; + break; + case 1: + StorageFlags |= StorageDecompressed; + break; + case 2: + StorageFlags |= StorageCompressLZ4; + break; + case 3: + StorageFlags |= StorageCompressZlib; + break; + } + switch (Settings->CompressionLevel) + { + case 0: + StorageFlags |= StorageCompressFastest; + break; + case 1: + StorageFlags |= StorageCompressFast; + break; + case 2: + StorageFlags |= StorageCompressNormal; + break; + case 3: + StorageFlags |= StorageCompressSlow; + break; + case 4: + StorageFlags |= StorageCompressSlowest; + break; + } + return StorageFlags; +} \ No newline at end of file diff --git a/gui/settings.h b/gui/settings.h new file mode 100644 index 0000000..c61fb7d --- /dev/null +++ b/gui/settings.h @@ -0,0 +1,22 @@ +#pragma once + +#include "../storage/storage.h" + +#define FILE_CONFIG_MAGIC 0xE6219B27 +#define FILE_CONFIG_VERSION 0 + +struct SETTINGS { + char CacheDir[_MAX_PATH + 1]; + char GameDir[_MAX_PATH + 1]; + char MountDrive; + uint8_t CompressionMethod; + uint8_t CompressionLevel; + bool VerifyCache; + bool EnableGaming; +}; + +bool SettingsRead(SETTINGS* Settings, FILE* File); +void SettingsWrite(SETTINGS* Settings, FILE* File); + +bool SettingsValidate(SETTINGS* Settings); +uint32_t SettingsGetStorageFlags(SETTINGS* Settings); \ No newline at end of file diff --git a/gui/wxModalWindow.cpp b/gui/wxModalWindow.cpp new file mode 100644 index 0000000..76fcb1c --- /dev/null +++ b/gui/wxModalWindow.cpp @@ -0,0 +1,138 @@ +// Taken from https://forums.wxwidgets.org/viewtopic.php?t=20752#p89334 + +/********************************************************************************************** + * + * Filename : modalwindow.cpp + * Purpose : Allow a modalwindow like wxDialog but allowing menus and such. + * Author : John A. Mason + * Created : 8/27/2008 07:54:12 AM + * Copyright : Released under wxWidgets original license. + * + **********************************************************************************************/ + +#include "wxModalWindow.h" + +wxModalWindow::wxModalWindow() + : wxFrame() { + Init(); +} + +wxModalWindow::wxModalWindow(wxWindow* parent, wxWindowID id, const wxString& title, const wxPoint& pos, const wxSize& size, long style, const wxString& name) + : wxFrame(parent, id, title, pos, size, style, name) { + Init(); +} + + +wxModalWindow::~wxModalWindow() { + delete m_eventLoop; +} + + +void wxModalWindow::Init() +{ + m_returnCode = 0; + m_windowDisabler = NULL; + m_eventLoop = NULL; + m_isShowingModal = false; +} + +bool wxModalWindow::Show(bool show) +{ + if (!show) + { + // if we had disabled other app windows, reenable them back now because + // if they stay disabled Windows will activate another window (one + // which is enabled, anyhow) and we will lose activation + if (m_windowDisabler) + { + delete m_windowDisabler; + m_windowDisabler = NULL; + } + + if (IsModal()) + EndModal(wxID_CANCEL); + } + + bool ret = wxFrame::Show(show); + + //I don't think we need this. Since it is a wxFrame that we are extending, + // we don't need wxEVT_INIT_DIALOG firing off - that's what InitDialog does... + // and this would only make sense if we have a wxDialog and validators +// if ( show ) + //InitDialog(); + + return ret; +} + +bool wxModalWindow::IsModal() const { + return m_isShowingModal; +} + +int wxModalWindow::ShowModal() { + if (IsModal()) + { + wxFAIL_MSG(wxT("wxModalWindow:ShowModal called twice")); + return GetReturnCode(); + } + + // use the apps top level window as parent if none given unless explicitly + // forbidden + if (!GetParent()) + { + wxWindow* parent = wxTheApp->GetTopWindow(); + if (parent && parent != this) + { + m_parent = parent; + } + } + + Show(true); + + m_isShowingModal = true; + + wxASSERT_MSG(!m_windowDisabler, _T("disabling windows twice?")); + +#if defined(__WXGTK__) || defined(__WXMGL__) + wxBusyCursorSuspender suspender; + // FIXME (FIXME_MGL) - make sure busy cursor disappears under MSW too +#endif + + m_windowDisabler = new wxWindowDisabler(this); + if (!m_eventLoop) + m_eventLoop = new wxEventLoop; + + m_eventLoop->Run(); + + return GetReturnCode(); +} + + + +void wxModalWindow::EndModal(int retCode) { + wxASSERT_MSG(m_eventLoop, _T("wxModalWindow is not modal")); + + SetReturnCode(retCode); + + if (!IsModal()) + { + wxFAIL_MSG(wxT("wxModalWindow:EndModal called twice")); + return; + } + + m_isShowingModal = false; + + m_eventLoop->Exit(); + + Show(false); +} + + +void wxModalWindow::SetReturnCode(int retCode) { + m_returnCode = retCode; +} + + +int wxModalWindow::GetReturnCode() const { + return m_returnCode; +} + diff --git a/gui/wxModalWindow.h b/gui/wxModalWindow.h new file mode 100644 index 0000000..3afd88d --- /dev/null +++ b/gui/wxModalWindow.h @@ -0,0 +1,61 @@ +// Taken from https://forums.wxwidgets.org/viewtopic.php?t=20752#p89334 + +/********************************************************************************************** + * + * Filename : modalwindow.h + * Purpose : Allow a modalwindow like wxDialog but allowing menus and such. + * Author : John A. Mason + * Created : 8/27/2008 07:54:12 AM + * Copyright : Released under wxWidgets original license. + * + **********************************************************************************************/ + +#ifndef __wx_ModalWindow_h__ +#define __wx_ModalWindow_h__ + +#include + +#ifdef __BORLANDC__ +#pragma hdrstop +#endif + +#include + +#ifndef WX_PRECOMP +#include +#include +#endif + +#include + +class wxModalWindow : public wxFrame { +private: + // while we are showing a modal window we disable the other windows using + // this object + wxWindowDisabler* m_windowDisabler; + + // modal window runs its own event loop + wxEventLoop* m_eventLoop; + + // is modal right now? + bool m_isShowingModal; + + //The return code of a modal window + int m_returnCode; +public: + wxModalWindow(); + wxModalWindow(wxWindow* parent, wxWindowID id, const wxString& title, const wxPoint& pos = wxDefaultPosition, const wxSize& size = wxDefaultSize, long style = wxDEFAULT_FRAME_STYLE, const wxString& name = "modalwindow"); + virtual ~wxModalWindow(); + + + void Init(); + bool Show(bool show); + bool IsModal() const; + int ShowModal(); + + void EndModal(int retCode); + void SetReturnCode(int retCode); + int GetReturnCode() const; +}; + +#endif diff --git a/icon.ico b/icon.ico new file mode 100644 index 0000000..0a0c80d Binary files /dev/null and b/icon.ico differ diff --git a/icon.svg b/icon.svg new file mode 100644 index 0000000..b387de3 --- /dev/null +++ b/icon.svg @@ -0,0 +1,108 @@ + + + EGL2 + + + + image/svg+xml + + EGL2 + + + + + + + + + + + + + + + + diff --git a/libdeflate/libdeflate.h b/libdeflate/libdeflate.h new file mode 100644 index 0000000..10fb187 --- /dev/null +++ b/libdeflate/libdeflate.h @@ -0,0 +1,328 @@ +/* + * libdeflate.h - public header for libdeflate + */ + +#ifndef LIBDEFLATE_H +#define LIBDEFLATE_H + +#ifdef __cplusplus +extern "C" { +#endif + +#define LIBDEFLATE_VERSION_MAJOR 1 +#define LIBDEFLATE_VERSION_MINOR 5 +#define LIBDEFLATE_VERSION_STRING "1.5" + +#include +#include + +/* + * On Windows, if you want to link to the DLL version of libdeflate, then + * #define LIBDEFLATE_DLL. Note that the calling convention is "stdcall". + */ +#ifdef LIBDEFLATE_DLL +# ifdef BUILDING_LIBDEFLATE +# define LIBDEFLATEEXPORT LIBEXPORT +# elif defined(_WIN32) || defined(__CYGWIN__) +# define LIBDEFLATEEXPORT __declspec(dllimport) +# endif +#endif +#ifndef LIBDEFLATEEXPORT +# define LIBDEFLATEEXPORT +#endif + +#if defined(_WIN32) && !defined(_WIN64) +# define LIBDEFLATEAPI_ABI __stdcall +#else +# define LIBDEFLATEAPI_ABI +#endif + +#if defined(BUILDING_LIBDEFLATE) && defined(__GNUC__) && \ + defined(_WIN32) && !defined(_WIN64) + /* + * On 32-bit Windows, gcc assumes 16-byte stack alignment but MSVC only 4. + * Realign the stack when entering libdeflate to avoid crashing in SSE/AVX + * code when called from an MSVC-compiled application. + */ +# define LIBDEFLATEAPI_STACKALIGN __attribute__((force_align_arg_pointer)) +#else +# define LIBDEFLATEAPI_STACKALIGN +#endif + +#define LIBDEFLATEAPI LIBDEFLATEAPI_ABI LIBDEFLATEAPI_STACKALIGN + +/* ========================================================================== */ +/* Compression */ +/* ========================================================================== */ + +struct libdeflate_compressor; + +/* + * libdeflate_alloc_compressor() allocates a new compressor that supports + * DEFLATE, zlib, and gzip compression. 'compression_level' is the compression + * level on a zlib-like scale but with a higher maximum value (1 = fastest, 6 = + * medium/default, 9 = slow, 12 = slowest). The return value is a pointer to + * the new compressor, or NULL if out of memory. + * + * Note: for compression, the sliding window size is defined at compilation time + * to 32768, the largest size permissible in the DEFLATE format. It cannot be + * changed at runtime. + * + * A single compressor is not safe to use by multiple threads concurrently. + * However, different threads may use different compressors concurrently. + */ +LIBDEFLATEEXPORT struct libdeflate_compressor * LIBDEFLATEAPI +libdeflate_alloc_compressor(int compression_level); + +/* + * libdeflate_deflate_compress() performs raw DEFLATE compression on a buffer of + * data. The function attempts to compress 'in_nbytes' bytes of data located at + * 'in' and write the results to 'out', which has space for 'out_nbytes_avail' + * bytes. The return value is the compressed size in bytes, or 0 if the data + * could not be compressed to 'out_nbytes_avail' bytes or fewer. + */ +LIBDEFLATEEXPORT size_t LIBDEFLATEAPI +libdeflate_deflate_compress(struct libdeflate_compressor *compressor, + const void *in, size_t in_nbytes, + void *out, size_t out_nbytes_avail); + +/* + * libdeflate_deflate_compress_bound() returns a worst-case upper bound on the + * number of bytes of compressed data that may be produced by compressing any + * buffer of length less than or equal to 'in_nbytes' using + * libdeflate_deflate_compress() with the specified compressor. Mathematically, + * this bound will necessarily be a number greater than or equal to 'in_nbytes'. + * It may be an overestimate of the true upper bound. The return value is + * guaranteed to be the same for all invocations with the same compressor and + * same 'in_nbytes'. + * + * As a special case, 'compressor' may be NULL. This causes the bound to be + * taken across *any* libdeflate_compressor that could ever be allocated with + * this build of the library, with any options. + * + * Note that this function is not necessary in many applications. With + * block-based compression, it is usually preferable to separately store the + * uncompressed size of each block and to store any blocks that did not compress + * to less than their original size uncompressed. In that scenario, there is no + * need to know the worst-case compressed size, since the maximum number of + * bytes of compressed data that may be used would always be one less than the + * input length. You can just pass a buffer of that size to + * libdeflate_deflate_compress() and store the data uncompressed if + * libdeflate_deflate_compress() returns 0, indicating that the compressed data + * did not fit into the provided output buffer. + */ +LIBDEFLATEEXPORT size_t LIBDEFLATEAPI +libdeflate_deflate_compress_bound(struct libdeflate_compressor *compressor, + size_t in_nbytes); + +/* + * Like libdeflate_deflate_compress(), but stores the data in the zlib wrapper + * format. + */ +LIBDEFLATEEXPORT size_t LIBDEFLATEAPI +libdeflate_zlib_compress(struct libdeflate_compressor *compressor, + const void *in, size_t in_nbytes, + void *out, size_t out_nbytes_avail); + +/* + * Like libdeflate_deflate_compress_bound(), but assumes the data will be + * compressed with libdeflate_zlib_compress() rather than with + * libdeflate_deflate_compress(). + */ +LIBDEFLATEEXPORT size_t LIBDEFLATEAPI +libdeflate_zlib_compress_bound(struct libdeflate_compressor *compressor, + size_t in_nbytes); + +/* + * Like libdeflate_deflate_compress(), but stores the data in the gzip wrapper + * format. + */ +LIBDEFLATEEXPORT size_t LIBDEFLATEAPI +libdeflate_gzip_compress(struct libdeflate_compressor *compressor, + const void *in, size_t in_nbytes, + void *out, size_t out_nbytes_avail); + +/* + * Like libdeflate_deflate_compress_bound(), but assumes the data will be + * compressed with libdeflate_gzip_compress() rather than with + * libdeflate_deflate_compress(). + */ +LIBDEFLATEEXPORT size_t LIBDEFLATEAPI +libdeflate_gzip_compress_bound(struct libdeflate_compressor *compressor, + size_t in_nbytes); + +/* + * libdeflate_free_compressor() frees a compressor that was allocated with + * libdeflate_alloc_compressor(). If a NULL pointer is passed in, no action is + * taken. + */ +LIBDEFLATEEXPORT void LIBDEFLATEAPI +libdeflate_free_compressor(struct libdeflate_compressor *compressor); + +/* ========================================================================== */ +/* Decompression */ +/* ========================================================================== */ + +struct libdeflate_decompressor; + +/* + * libdeflate_alloc_decompressor() allocates a new decompressor that can be used + * for DEFLATE, zlib, and gzip decompression. The return value is a pointer to + * the new decompressor, or NULL if out of memory. + * + * This function takes no parameters, and the returned decompressor is valid for + * decompressing data that was compressed at any compression level and with any + * sliding window size. + * + * A single decompressor is not safe to use by multiple threads concurrently. + * However, different threads may use different decompressors concurrently. + */ +LIBDEFLATEEXPORT struct libdeflate_decompressor * LIBDEFLATEAPI +libdeflate_alloc_decompressor(void); + +/* + * Result of a call to libdeflate_deflate_decompress(), + * libdeflate_zlib_decompress(), or libdeflate_gzip_decompress(). + */ +enum libdeflate_result { + /* Decompression was successful. */ + LIBDEFLATE_SUCCESS = 0, + + /* Decompressed failed because the compressed data was invalid, corrupt, + * or otherwise unsupported. */ + LIBDEFLATE_BAD_DATA = 1, + + /* A NULL 'actual_out_nbytes_ret' was provided, but the data would have + * decompressed to fewer than 'out_nbytes_avail' bytes. */ + LIBDEFLATE_SHORT_OUTPUT = 2, + + /* The data would have decompressed to more than 'out_nbytes_avail' + * bytes. */ + LIBDEFLATE_INSUFFICIENT_SPACE = 3, +}; + +/* + * libdeflate_deflate_decompress() decompresses the DEFLATE-compressed stream + * from the buffer 'in' with compressed size up to 'in_nbytes' bytes. The + * uncompressed data is written to 'out', a buffer with size 'out_nbytes_avail' + * bytes. If decompression succeeds, then 0 (LIBDEFLATE_SUCCESS) is returned. + * Otherwise, a nonzero result code such as LIBDEFLATE_BAD_DATA is returned. If + * a nonzero result code is returned, then the contents of the output buffer are + * undefined. + * + * Decompression stops at the end of the DEFLATE stream (as indicated by the + * BFINAL flag), even if it is actually shorter than 'in_nbytes' bytes. + * + * libdeflate_deflate_decompress() can be used in cases where the actual + * uncompressed size is known (recommended) or unknown (not recommended): + * + * - If the actual uncompressed size is known, then pass the actual + * uncompressed size as 'out_nbytes_avail' and pass NULL for + * 'actual_out_nbytes_ret'. This makes libdeflate_deflate_decompress() fail + * with LIBDEFLATE_SHORT_OUTPUT if the data decompressed to fewer than the + * specified number of bytes. + * + * - If the actual uncompressed size is unknown, then provide a non-NULL + * 'actual_out_nbytes_ret' and provide a buffer with some size + * 'out_nbytes_avail' that you think is large enough to hold all the + * uncompressed data. In this case, if the data decompresses to less than + * or equal to 'out_nbytes_avail' bytes, then + * libdeflate_deflate_decompress() will write the actual uncompressed size + * to *actual_out_nbytes_ret and return 0 (LIBDEFLATE_SUCCESS). Otherwise, + * it will return LIBDEFLATE_INSUFFICIENT_SPACE if the provided buffer was + * not large enough but no other problems were encountered, or another + * nonzero result code if decompression failed for another reason. + */ +LIBDEFLATEEXPORT enum libdeflate_result LIBDEFLATEAPI +libdeflate_deflate_decompress(struct libdeflate_decompressor *decompressor, + const void *in, size_t in_nbytes, + void *out, size_t out_nbytes_avail, + size_t *actual_out_nbytes_ret); + +/* + * Like libdeflate_deflate_decompress(), but adds the 'actual_in_nbytes_ret' + * argument. If decompression succeeds and 'actual_in_nbytes_ret' is not NULL, + * then the actual compressed size of the DEFLATE stream (aligned to the next + * byte boundary) is written to *actual_in_nbytes_ret. + */ +LIBDEFLATEEXPORT enum libdeflate_result LIBDEFLATEAPI +libdeflate_deflate_decompress_ex(struct libdeflate_decompressor *decompressor, + const void *in, size_t in_nbytes, + void *out, size_t out_nbytes_avail, + size_t *actual_in_nbytes_ret, + size_t *actual_out_nbytes_ret); + +/* + * Like libdeflate_deflate_decompress(), but assumes the zlib wrapper format + * instead of raw DEFLATE. + */ +LIBDEFLATEEXPORT enum libdeflate_result LIBDEFLATEAPI +libdeflate_zlib_decompress(struct libdeflate_decompressor *decompressor, + const void *in, size_t in_nbytes, + void *out, size_t out_nbytes_avail, + size_t *actual_out_nbytes_ret); + +/* + * Like libdeflate_deflate_decompress(), but assumes the gzip wrapper format + * instead of raw DEFLATE. + * + * If multiple gzip-compressed members are concatenated, then only the first + * will be decompressed. Use libdeflate_gzip_decompress_ex() if you need + * multi-member support. + */ +LIBDEFLATEEXPORT enum libdeflate_result LIBDEFLATEAPI +libdeflate_gzip_decompress(struct libdeflate_decompressor *decompressor, + const void *in, size_t in_nbytes, + void *out, size_t out_nbytes_avail, + size_t *actual_out_nbytes_ret); + +/* + * Like libdeflate_gzip_decompress(), but adds the 'actual_in_nbytes_ret' + * argument. If 'actual_in_nbytes_ret' is not NULL and the decompression + * succeeds (indicating that the first gzip-compressed member in the input + * buffer was decompressed), then the actual number of input bytes consumed is + * written to *actual_in_nbytes_ret. + */ +LIBDEFLATEEXPORT enum libdeflate_result LIBDEFLATEAPI +libdeflate_gzip_decompress_ex(struct libdeflate_decompressor *decompressor, + const void *in, size_t in_nbytes, + void *out, size_t out_nbytes_avail, + size_t *actual_in_nbytes_ret, + size_t *actual_out_nbytes_ret); + +/* + * libdeflate_free_decompressor() frees a decompressor that was allocated with + * libdeflate_alloc_decompressor(). If a NULL pointer is passed in, no action + * is taken. + */ +LIBDEFLATEEXPORT void LIBDEFLATEAPI +libdeflate_free_decompressor(struct libdeflate_decompressor *decompressor); + +/* ========================================================================== */ +/* Checksums */ +/* ========================================================================== */ + +/* + * libdeflate_adler32() updates a running Adler-32 checksum with 'len' bytes of + * data and returns the updated checksum. When starting a new checksum, the + * required initial value for 'adler' is 1. This value is also returned when + * 'buffer' is specified as NULL. + */ +LIBDEFLATEEXPORT uint32_t LIBDEFLATEAPI +libdeflate_adler32(uint32_t adler32, const void *buffer, size_t len); + + +/* + * libdeflate_crc32() updates a running CRC-32 checksum with 'len' bytes of data + * and returns the updated checksum. When starting a new checksum, the required + * initial value for 'crc' is 0. This value is also returned when 'buffer' is + * specified as NULL. + */ +LIBDEFLATEEXPORT uint32_t LIBDEFLATEAPI +libdeflate_crc32(uint32_t crc, const void *buffer, size_t len); + +#ifdef __cplusplus +} +#endif + +#endif /* LIBDEFLATE_H */ diff --git a/libdeflate/libdeflatestatic.lib b/libdeflate/libdeflatestatic.lib new file mode 100644 index 0000000..7702e54 Binary files /dev/null and b/libdeflate/libdeflatestatic.lib differ diff --git a/main.cpp b/main.cpp deleted file mode 100644 index 0bbad34..0000000 --- a/main.cpp +++ /dev/null @@ -1,137 +0,0 @@ -/** - * @file memfs-main.c - * - * @copyright 2015-2020 Bill Zissimopoulos - */ - /* - * This file is part of WinFsp. - * - * You can redistribute it and/or modify it under the terms of the GNU - * General Public License version 3 as published by the Free Software - * Foundation. - * - * Licensees holding a valid commercial license may use this software - * in accordance with the commercial license agreement provided in - * conjunction with the software. The terms and conditions of any such - * commercial license agreement shall govern, supersede, and render - * ineffective any application of the GPLv3 license to this software, - * notwithstanding of any reference thereto in the software or - * associated repository. - */ - -#include -#include "memfs/memfs.h" -#include - -#define info(format, ...) FspServiceLog(EVENTLOG_INFORMATION_TYPE, format, __VA_ARGS__) -#define warn(format, ...) FspServiceLog(EVENTLOG_WARNING_TYPE, format, __VA_ARGS__) -#define fail(format, ...) FspServiceLog(EVENTLOG_ERROR_TYPE, format, __VA_ARGS__) - -static void InitializeFilesystem(MEMFS* Memfs) { - CreateFsFile(Memfs, L"\\eee", false); - //CreateFsFile(Memfs, L"\\eee\\text.txt", false); -} - -static PVOID FileOpen(PCWSTR fileName, UINT64* fileSize) { - wprintf(L"OPENING %s\n", fileName); - *fileSize = 400; - return (void*)34; -} - -static void FileClose(PVOID Handle) { - wprintf(L"CLOSING %d\n", (int)Handle); - // do nothing -} - -static void FileRead(PVOID Handle, PVOID Buffer, UINT64 offset, ULONG length, ULONG* bytesRead) { - wprintf(L"READING %d\n", (int)Handle); - memset(Buffer, 0x0F, length); - *bytesRead = length; -} - -NTSTATUS SvcStart(FSP_SERVICE* Service, ULONG argc, PWSTR* argv) -{ - MEMFS* Memfs = 0; - NTSTATUS Result; - - FspDebugLogSetHandle(GetStdHandle(STD_ERROR_HANDLE)); - - MEMFS_FILE_PROVIDER* Provider = CreateProvider(FileOpen, FileClose, FileRead); - - Result = MemfsCreateFunnel( - MemfsDisk, // flags - INFINITE, // file timeout - 1024, // max file nodes/files - 0, // file system name - 0, // volume prefix - Provider, - &Memfs); - if (!NT_SUCCESS(Result)) - { - fail(L"cannot create MEMFS"); - goto exit; - } - - InitializeFilesystem(Memfs); - - FspFileSystemSetDebugLog(MemfsFileSystem(Memfs), 0); // can also be -1 for all flags - - { - PSECURITY_DESCRIPTOR RootSecurity; - - const TCHAR* rootSddl = - SDDL_DACL SDDL_DELIMINATOR - SDDL_ACE_COND_BEGIN - SDDL_ACCESS_ALLOWED SDDL_SEPERATOR // Allowed (ace_type) - SDDL_OBJECT_INHERIT SDDL_CONTAINER_INHERIT SDDL_SEPERATOR // Inherit to containers and objects (ace_flags) - SDDL_GENERIC_READ SDDL_GENERIC_EXECUTE SDDL_SEPERATOR // Allow reads and executes (rights) - SDDL_SEPERATOR // object_guid - SDDL_SEPERATOR // inherit_object_guid - SDDL_EVERYONE // Give rights to everyone (account_sid) - SDDL_ACE_COND_END; - // D:(A;OICI;GRGX;;;WD) - - if (!ConvertStringSecurityDescriptorToSecurityDescriptor(rootSddl, SDDL_REVISION_1, &RootSecurity, NULL)) { - fail(L"invalid sddl: %08x", FspNtStatusFromWin32(GetLastError())); - goto exit; - } - Result = FspFileSystemSetMountPointEx(MemfsFileSystem(Memfs), L"C:\\aaaa", RootSecurity); - LocalFree(RootSecurity); - if (!NT_SUCCESS(Result)) - { - fail(L"cannot mount MEMFS %08x", Result); - goto exit; - } - } - - Result = MemfsStart(Memfs); - if (!NT_SUCCESS(Result)) - { - fail(L"cannot start MEMFS %08x", Result); - goto exit; - } - - Service->UserContext = Memfs; - Result = STATUS_SUCCESS; - -exit: - if (!NT_SUCCESS(Result) && 0 != Memfs) - MemfsDelete(Memfs); - - return Result; -} - -NTSTATUS SvcStop(FSP_SERVICE* Service) -{ - MEMFS* Memfs = (MEMFS*)Service->UserContext; - - MemfsStop(Memfs); - MemfsDelete(Memfs); - - return STATUS_SUCCESS; -} - -int wmain(int argc, wchar_t** argv) -{ - return FspServiceRun(L"EGL2", SvcStart, SvcStop, 0); -} \ No newline at end of file diff --git a/memfs/memfs.h b/memfs/memfs.h deleted file mode 100644 index b48a938..0000000 --- a/memfs/memfs.h +++ /dev/null @@ -1,78 +0,0 @@ -/** - * @file memfs.h - * - * @copyright 2015-2020 Bill Zissimopoulos - */ - /* - * This file is part of WinFsp. - * - * You can redistribute it and/or modify it under the terms of the GNU - * General Public License version 3 as published by the Free Software - * Foundation. - * - * Licensees holding a valid commercial license may use this software - * in accordance with the commercial license agreement provided in - * conjunction with the software. The terms and conditions of any such - * commercial license agreement shall govern, supersede, and render - * ineffective any application of the GPLv3 license to this software, - * notwithstanding of any reference thereto in the software or - * associated repository. - */ - -#ifndef MEMFS_H_INCLUDED -#define MEMFS_H_INCLUDED - -#include -#include - -#define MEMFS_MAX_PATH 512 -FSP_FSCTL_STATIC_ASSERT(MEMFS_MAX_PATH > MAX_PATH, - "MEMFS_MAX_PATH must be greater than MAX_PATH."); - -#define MEMFS_SECTOR_SIZE 512 -#define MEMFS_SECTORS_PER_ALLOCATION_UNIT 1 - -#ifdef __cplusplus -extern "C" { -#endif - - typedef struct _MEMFS MEMFS; - - typedef struct _MEMFS_FILE_PROVIDER MEMFS_FILE_PROVIDER; - - enum - { - MemfsDisk = 0x00000000, - MemfsNet = 0x00000001, - MemfsDeviceMask = 0x0000000f, - MemfsCaseInsensitive = 0x80000000, - }; - - NTSTATUS MemfsCreateFunnel( - ULONG Flags, - ULONG FileInfoTimeout, - ULONG MaxFileNodes, - PWSTR FileSystemName, - PWSTR VolumePrefix, - MEMFS_FILE_PROVIDER* FileProvider, - MEMFS** PMemfs); - VOID MemfsDelete(MEMFS* Memfs); - NTSTATUS MemfsStart(MEMFS* Memfs); - VOID MemfsStop(MEMFS* Memfs); - FSP_FILE_SYSTEM* MemfsFileSystem(MEMFS* Memfs); - - MEMFS_FILE_PROVIDER* CreateProvider( - std::function Open, - std::function Close, - std::function Read); - void CloseProvider(MEMFS_FILE_PROVIDER* Provider); - - NTSTATUS CreateFsFile(MEMFS* Memfs, - PWSTR FileName, BOOLEAN Directory); - -#ifdef __cplusplus -} -#endif - -#endif - diff --git a/storage/compression.cpp b/storage/compression.cpp new file mode 100644 index 0000000..03383b3 --- /dev/null +++ b/storage/compression.cpp @@ -0,0 +1,115 @@ +#include "compression.h" +#include "storage.h" + +#include +#include +#include + +bool ZlibDecompress(FILE* File, DECOMPRESS_ALLOCATOR Allocator) { + uint32_t uncompressedSize; + fread(&uncompressedSize, sizeof(uint32_t), 1, File); + + auto pos = ftell(File); + fseek(File, 0, SEEK_END); + auto inBufSize = ftell(File) - pos; + fseek(File, pos, SEEK_SET); + + char* inBuffer = new char[inBufSize]; + fread(inBuffer, 1, inBufSize, File); + + auto decompressor = libdeflate_alloc_decompressor(); + + auto& outBuffer = Allocator(uncompressedSize); + libdeflate_zlib_decompress(decompressor, inBuffer, inBufSize, outBuffer.get(), uncompressedSize, NULL); + + libdeflate_free_decompressor(decompressor); + delete[] inBuffer; + return true; +} + +bool LZ4Decompress(FILE* File, DECOMPRESS_ALLOCATOR Allocator) { + uint32_t uncompressedSize; + fread(&uncompressedSize, sizeof(uint32_t), 1, File); + + auto pos = ftell(File); + fseek(File, 0, SEEK_END); + auto inBufSize = ftell(File) - pos; + fseek(File, pos, SEEK_SET); + + char* inBuffer = new char[inBufSize]; + fread(inBuffer, 1, inBufSize, File); + + auto& outBuffer = Allocator(uncompressedSize); + LZ4_decompress_fast(inBuffer, outBuffer.get(), uncompressedSize); + + delete[] inBuffer; + return true; +} + +bool ZlibCompress(uint32_t Flags, const char* Buffer, uint32_t BufferSize, char** POutBuffer, uint32_t* POutBufferSize) { + int compression_level; + if (Flags & StorageCompressFastest) { + compression_level = 1; + } + else if (Flags & StorageCompressFast) { + compression_level = 4; + } + else if (Flags & StorageCompressNormal) { + compression_level = 6; + } + else if (Flags & StorageCompressSlow) { + compression_level = 9; + } + else if (Flags & StorageCompressSlowest) { + compression_level = 12; + } + else { + return false; + } + + auto compressor = libdeflate_alloc_compressor(compression_level); + + uint32_t outBufSize = libdeflate_zlib_compress_bound(compressor, BufferSize); + char* outBuffer = new char[outBufSize]; + + uint32_t compressedSize = libdeflate_zlib_compress(compressor, Buffer, BufferSize, outBuffer, outBufSize); + + *POutBuffer = outBuffer; + *POutBufferSize = compressedSize; + + libdeflate_free_compressor(compressor); +} + +bool LZ4Compress(uint32_t Flags, const char* Buffer, uint32_t BufferSize, char** POutBuffer, uint32_t* POutBufferSize) { + int compression_level; + if (Flags & StorageCompressFastest) { + compression_level = LZ4HC_CLEVEL_MIN; + } + else if (Flags & StorageCompressFast) { + compression_level = 6; + } + else if (Flags & StorageCompressNormal) { + compression_level = LZ4HC_CLEVEL_DEFAULT; + } + else if (Flags & StorageCompressSlow) { + compression_level = LZ4HC_CLEVEL_OPT_MIN; // uses "HC" at this point + } + else if (Flags & StorageCompressSlowest) { + compression_level = LZ4HC_CLEVEL_MAX; + } + else { + return false; + } + + uint32_t outBufSize = LZ4_COMPRESSBOUND(BufferSize); + char* outBuffer = new char[outBufSize]; + + uint32_t compressedSize = LZ4_compress_HC(Buffer, outBuffer, BufferSize, outBufSize, compression_level); + + *POutBuffer = outBuffer; + *POutBufferSize = compressedSize; +} + +void DeleteCompressBuffer(char* OutBuffer) { + delete[] OutBuffer; +} \ No newline at end of file diff --git a/storage/compression.h b/storage/compression.h new file mode 100644 index 0000000..666e2bf --- /dev/null +++ b/storage/compression.h @@ -0,0 +1,14 @@ +#pragma once + +#include +#include + +typedef std::function& (uint32_t size)> DECOMPRESS_ALLOCATOR; + +bool ZlibDecompress(FILE* File, DECOMPRESS_ALLOCATOR Allocator); + +bool LZ4Decompress(FILE* File, DECOMPRESS_ALLOCATOR Allocator); + +bool ZlibCompress(uint32_t Flags, const char* Buffer, uint32_t BufferSize, char** OutBuffer, uint32_t* POutBufferSize); + +bool LZ4Compress(uint32_t Flags, const char* Buffer, uint32_t BufferSize, char** OutBuffer, uint32_t* POutBufferSize); diff --git a/storage/sha.h b/storage/sha.h new file mode 100644 index 0000000..fb6686f --- /dev/null +++ b/storage/sha.h @@ -0,0 +1,11 @@ +#pragma once + +#include +#include + +bool VerifyHash(const char* input, uint32_t inputSize, const char Sha[20]) { + char calculatedHash[20]; + SHA1((const uint8_t*)input, inputSize, (uint8_t*)calculatedHash); + + return !memcmp(Sha, calculatedHash, 20); +} \ No newline at end of file diff --git a/storage/storage.cpp b/storage/storage.cpp new file mode 100644 index 0000000..0581725 --- /dev/null +++ b/storage/storage.cpp @@ -0,0 +1,428 @@ +#include "storage.h" +#include "../web/http.h" +#include "../containers/iterable_queue.h" +#include "compression.h" +#include "sha.h" + +#include +#include +#include + +namespace fs = std::filesystem; + +#define STORAGE_CHUNKS_RESERVE 64 + +enum class CHUNK_STATUS { + Unavailable, // Readable from download + Grabbing, // Downloading + Available, // Readable from local copy + Reading, // Reading from local copy + Readable // Readable from memory +}; + +typedef struct _CHUNK_POOL_DATA { + std::unique_ptr Buffer; + std::condition_variable CV; + std::mutex Mutex; + CHUNK_STATUS Status; +} CHUNK_POOL_DATA; + +auto hash = [](const char* n) { return (*((uint64_t*)n)) ^ (*(((uint64_t*)n) + 1)); }; +auto equal = [](const char* a, const char* b) {return !memcmp(a, b, 16); }; +// this can be an ordered map, but i'm unsure about the memory usage of this, even if all 81k chunks are read +typedef iterable_queue>> STORAGE_CHUNK_POOL_LOOKUP; + +typedef struct _STORAGE { + fs::path cachePath; + uint32_t flags; + std::string CloudDir; // CloudDir also includes the /ChunksV3/ part, though + std::unique_ptr Client; + std::unique_ptr ChunkPoolMutex; + STORAGE_CHUNK_POOL_LOOKUP ChunkPool; +} STORAGE; + +bool StorageCreate( + uint32_t Flags, + const wchar_t* CacheLocation, + const char* ChunkHost, + const char* CloudDir, + STORAGE** PStorage) { + STORAGE* Storage = new STORAGE; + Storage->cachePath = fs::path(CacheLocation); + Storage->flags = Flags; + Storage->CloudDir = CloudDir; + Storage->Client = std::make_unique(ChunkHost); + Storage->ChunkPoolMutex = std::make_unique(); + *PStorage = Storage; + return true; +} + +void StorageDelete(STORAGE* Storage) { + // delete buffers and stuff in pool lookup table + delete Storage; +} + +inline CHUNK_STATUS StorageGetChunkStatus(STORAGE* Storage, char Guid[16]) { + char GuidString[33]; + sprintf(GuidString, "%016llX%016llX", ntohll(*(uint64_t*)Guid), ntohll(*(uint64_t*)(Guid + 8))); + + char GuidFolder[3]; + memcpy(GuidFolder, GuidString, 2); + GuidFolder[2] = '\0'; + + if (fs::status(Storage->cachePath / GuidFolder / GuidString).type() != fs::file_type::regular) { + return CHUNK_STATUS::Unavailable; + } + else { + return CHUNK_STATUS::Available; + } +} + +std::weak_ptr StorageGetPoolData(STORAGE* Storage, char Guid[16]) { + std::lock_guard statusLock(*Storage->ChunkPoolMutex); + for (auto& chunk : Storage->ChunkPool) { + if (!memcmp(Guid, chunk.first, 16)) { + return chunk.second; + } + } + + if (Storage->ChunkPool.size() == STORAGE_CHUNKS_RESERVE) + { + Storage->ChunkPool.pop(); + } + + auto data = std::make_shared(); + + data->Buffer = std::unique_ptr(); + data->Status = StorageGetChunkStatus(Storage, Guid); + + Storage->ChunkPool.push(std::make_pair(Guid, data)); + return data; +} + +std::unique_ptr& StorageGetBuffer(STORAGE* Storage, char Guid[16]) { + return StorageGetPoolData(Storage, Guid).lock()->Buffer; +} + +void StorageSetBuffer(std::shared_ptr data, uint32_t size) { + data->Buffer.reset(new char[size]); +} + +#pragma pack(push, 1) +typedef struct _STORAGE_CHUNK_HEADER { + uint16_t version; + uint16_t flags; +} STORAGE_CHUNK_HEADER; + +#define CHUNK_HEADER_MAGIC 0xB1FE3AA2 +typedef struct _STORAGE_CDN_CHUNK_HEADER { + uint32_t Magic; + uint32_t Version; + uint32_t HeaderSize; + uint32_t DataSizeCompressed; + char Guid[16]; + uint64_t RollingHash; + uint8_t StoredAs; // EChunkStorageFlags +} STORAGE_CDN_CHUNK_HEADER; + +typedef struct _STORAGE_CDN_CHUNK_HEADER_V2 { + char SHAHash[20]; + uint8_t HashType; // EChunkHashFlags +} STORAGE_CDN_CHUNK_HEADER_V2; + +typedef struct _STORAGE_CDN_CHUNK_HEADER_V3 { + uint32_t DataSizeUncompressed; +} STORAGE_CDN_CHUNK_HEADER_V3; +#pragma pack(pop) + +inline bool StorageRead(fs::path path, std::function&(uint32_t size)> allocator) { + auto fp = fopen(path.string().c_str(), "rb"); + STORAGE_CHUNK_HEADER header; + fread(&header, sizeof(STORAGE_CHUNK_HEADER), 1, fp); + if (header.version != 0) { + printf("bad version!\n"); + return false; + } + if (header.flags & StorageDecompressed) { + auto pos = ftell(fp); + fseek(fp, 0, SEEK_END); + auto inBufSize = ftell(fp) - pos; + fseek(fp, pos, SEEK_SET); + + auto& buffer = allocator(inBufSize); + fread(buffer.get(), 1, inBufSize, fp); + fclose(fp); + return true; + } + else if (header.flags & StorageCompressZlib) { // zlib compressed + auto result = ZlibDecompress(fp, allocator); + + fclose(fp); + return result; + } + else if (header.flags & StorageCompressLZ4) { // lz4 compressed + auto result = LZ4Decompress(fp, allocator); + + fclose(fp); + return result; + } + fclose(fp); + printf("unknown read flag!\n"); + return false; +} + +inline void StorageWrite(const char* Path, uint16_t Flags, uint32_t DecompressedSize, const char* Buffer, uint32_t BufferSize) { + auto fp = fopen(Path, "wb"); + if (!fp) + printf("ERRNO %s: %d\n", Path, errno); + STORAGE_CHUNK_HEADER chunkHeader; + chunkHeader.version = 0; + chunkHeader.flags = Flags; + fwrite(&chunkHeader, sizeof(STORAGE_CHUNK_HEADER), 1, fp); + if (!(Flags & StorageDecompressed)) { // Compressed chunks write the decompressed size + fwrite(&DecompressedSize, sizeof(uint32_t), 1, fp); + } + fwrite(Buffer, 1, BufferSize, fp); + fclose(fp); +} + +bool StorageChunkDownloaded(STORAGE* Storage, MANIFEST_CHUNK* Chunk) { + return StorageGetChunkStatus(Storage, ManifestChunkGetGuid(Chunk)) != CHUNK_STATUS::Unavailable; +} + +bool StorageVerifyChunk(STORAGE* Storage, MANIFEST_CHUNK* Chunk) { + if (StorageChunkDownloaded(Storage, Chunk)) { + auto guid = ManifestChunkGetGuid(Chunk); + + char GuidString[33]; + sprintf(GuidString, "%016llX%016llX", ntohll(*(uint64_t*)guid), ntohll(*(uint64_t*)(guid + 8))); + + char GuidFolder[3]; + memcpy(GuidFolder, GuidString, 2); + GuidFolder[2] = '\0'; + + std::unique_ptr _buf; + uint32_t _bufSize; + StorageRead(Storage->cachePath / GuidFolder / GuidString, [&](uint32_t size) -> std::unique_ptr& { + _buf.reset(new char[size]); + _bufSize = size; + return _buf; + }); + + if (!VerifyHash(_buf.get(), _bufSize, ManifestChunkGetSha1(Chunk))) { + fs::remove(Storage->cachePath / GuidFolder / GuidString); + StorageDownloadChunk(Storage, Chunk, [](const char* buf, uint32_t bufSize) {}); + } + return true; + } + else { + return false; + } +} + +void StorageDownloadChunk(STORAGE* Storage, MANIFEST_CHUNK* Chunk, std::function DataCallback) { + std::vector chunkData; + { + chunkData.reserve(8192); + bool chunkDataReserved = false; + + char UrlBuffer[256]; + strcpy(UrlBuffer, Storage->CloudDir.c_str()); + ManifestChunkAppendUrl(Chunk, UrlBuffer); + Storage->Client->Get(UrlBuffer, + [&](const char* data, uint64_t data_length) { + chunkData.insert(chunkData.end(), data, data + data_length); + return true; + }, + [&](uint64_t len, uint64_t total) { + if (!chunkDataReserved) { + chunkData.reserve(total); + chunkDataReserved = true; + } + return true; + } + ); + } + + uint32_t decompressedSize = 1024 * 1024; + + auto headerv1 = *(STORAGE_CDN_CHUNK_HEADER*)chunkData.data(); + auto chunkPos = sizeof(STORAGE_CDN_CHUNK_HEADER); + if (headerv1.Magic != CHUNK_HEADER_MAGIC) { + printf("magic invalid\n"); + return; + } + if (headerv1.Version >= 2) { + auto headerv2 = *(STORAGE_CDN_CHUNK_HEADER_V2*)(chunkData.data() + chunkPos); + chunkPos += sizeof(STORAGE_CDN_CHUNK_HEADER_V2); + if (headerv1.Version >= 3) { + auto headerv3 = *(STORAGE_CDN_CHUNK_HEADER_V3*)(chunkData.data() + chunkPos); + decompressedSize = headerv3.DataSizeUncompressed; + + if (headerv1.Version > 3) { // version past 3 + chunkPos = headerv1.HeaderSize; + } + } + } + + if (headerv1.StoredAs & 0x02) // encrypted + { + printf("encrypted?\n"); + return; // no support yet, i have never seen this used in practice + } + + char GuidString[33]; + { + auto guid = ManifestChunkGetGuid(Chunk); + sprintf(GuidString, "%016llX%016llX", ntohll(*(uint64_t*)guid), ntohll(*(uint64_t*)(guid + 8))); + } + + char GuidFolder[3]; + memcpy(GuidFolder, GuidString, 2); + GuidFolder[2] = '\0'; + + auto guidPathStr = (Storage->cachePath / GuidFolder / GuidString).string(); + auto guidPath = guidPathStr.c_str(); + auto bufferPtr = chunkData.data() + chunkPos; + + if (headerv1.StoredAs & 0x01) // compressed + { + auto data = std::make_unique(decompressedSize); + { + auto decompressor = libdeflate_alloc_decompressor(); + auto result = libdeflate_zlib_decompress(decompressor, bufferPtr, headerv1.DataSizeCompressed, data.get(), decompressedSize, NULL); + DataCallback(data.get(), decompressedSize); + libdeflate_free_decompressor(decompressor); + } + + if (Storage->flags & StorageCompressed) { + StorageWrite(guidPath, StorageCompressZlib, decompressedSize, bufferPtr, headerv1.DataSizeCompressed); + } + else if (Storage->flags & StorageDecompressed) { + StorageWrite(guidPath, StorageDecompressed, 0, data.get(), decompressedSize); + } + else if (Storage->flags & StorageCompressZlib) { + char* compressedBuffer; + uint32_t compressedBufferSize; + ZlibCompress(Storage->flags, data.get(), decompressedSize, &compressedBuffer, &compressedBufferSize); + StorageWrite(guidPath, StorageCompressZlib, decompressedSize, compressedBuffer, compressedBufferSize); + delete[] compressedBuffer; + } + else if (Storage->flags & StorageCompressLZ4) { + char* compressedBuffer; + uint32_t compressedBufferSize; + LZ4Compress(Storage->flags, data.get(), decompressedSize, &compressedBuffer, &compressedBufferSize); + StorageWrite(guidPath, StorageCompressLZ4, decompressedSize, compressedBuffer, compressedBufferSize); + delete[] compressedBuffer; + } + } + else { + DataCallback(bufferPtr, decompressedSize); + + if (Storage->flags & (StorageCompressed | StorageDecompressed)) { + StorageWrite(guidPath, StorageDecompressed, 0, bufferPtr, decompressedSize); + } + else if (Storage->flags & StorageCompressZlib) { + char* compressedBuffer; + uint32_t compressedBufferSize; + ZlibCompress(Storage->flags, bufferPtr, decompressedSize, &compressedBuffer, &compressedBufferSize); + StorageWrite(guidPath, StorageCompressZlib, decompressedSize, compressedBuffer, compressedBufferSize); + delete[] compressedBuffer; + } + else if (Storage->flags & StorageCompressLZ4) { + char* compressedBuffer; + uint32_t compressedBufferSize; + LZ4Compress(Storage->flags, bufferPtr, decompressedSize, &compressedBuffer, &compressedBufferSize); + StorageWrite(guidPath, StorageCompressLZ4, decompressedSize, compressedBuffer, compressedBufferSize); + delete[] compressedBuffer; + } + } +} + +// thread safe, downloads if needed, etc. +void StorageDownloadChunkPart(STORAGE* Storage, MANIFEST_CHUNK* Chunk, uint32_t Offset, uint32_t Size, char* Buffer) { + auto guid = ManifestChunkGetGuid(Chunk); + auto data = StorageGetPoolData(Storage, guid).lock(); + switch (data->Status) + { + case CHUNK_STATUS::Unavailable: +redownloadChunk: + { + // download + data->Status = CHUNK_STATUS::Grabbing; + + StorageDownloadChunk(Storage, Chunk, [&](const char* Buffer_, uint32_t BufferSize) { + std::unique_lock lck(data->Mutex); + StorageSetBuffer(data, BufferSize); + memcpy(data->Buffer.get(), Buffer_, BufferSize); + + memcpy(Buffer, Buffer_ + Offset, Size); + + data->Status = CHUNK_STATUS::Readable; + data->CV.notify_all(); + }); + break; + } + case CHUNK_STATUS::Available: + { + // read from file + data->Status = CHUNK_STATUS::Reading; + + char GuidString[33]; + sprintf(GuidString, "%016llX%016llX", ntohll(*(uint64_t*)guid), ntohll(*(uint64_t*)(guid + 8))); + + char GuidFolder[3]; + memcpy(GuidFolder, GuidString, 2); + GuidFolder[2] = '\0'; + + char* _buf; + uint32_t _bufSize; + StorageRead(Storage->cachePath / GuidFolder / GuidString, [&](uint32_t size) -> std::unique_ptr& { + auto data = StorageGetPoolData(Storage, guid).lock(); + StorageSetBuffer(data, size); + _buf = data->Buffer.get(); + _bufSize = size; + return data->Buffer; + }); + + if (Storage->flags & StorageVerifyHashes) { + if (!VerifyHash(_buf, _bufSize, ManifestChunkGetSha1(Chunk))) { + fs::remove(Storage->cachePath / GuidFolder / GuidString); + data->Status = CHUNK_STATUS::Unavailable; + StorageDownloadChunkPart(Storage, Chunk, Offset, Size, Buffer); + return; + } + } + + std::unique_lock lck(data->Mutex); + data->Status = CHUNK_STATUS::Readable; + data->CV.notify_all(); + + memcpy(Buffer, StorageGetBuffer(Storage, guid).get() + Offset, Size); + break; + } + case CHUNK_STATUS::Grabbing: // downloading from server, wait until mutex releases + case CHUNK_STATUS::Reading: // reading from file, wait until mutex releases + { + std::unique_lock lck(data->Mutex); + while (data->Status != CHUNK_STATUS::Readable) data->CV.wait(lck); + + memcpy(Buffer, StorageGetBuffer(Storage, guid).get() + Offset, Size); + break; + } + case CHUNK_STATUS::Readable: // available in memory pool + { + memcpy(Buffer, StorageGetBuffer(Storage, guid).get() + Offset, Size); + break; + } + default: + // h o w + break; + } +} + +void StorageDownloadChunkPart(STORAGE* Storage, MANIFEST_CHUNK_PART* ChunkPart, char* Buffer) { + uint32_t Offset, Size; + ManifestFileChunkGetData(ChunkPart, &Offset, &Size); + StorageDownloadChunkPart(Storage, ManifestFileChunkGetChunk(ChunkPart), Offset, Size, Buffer); +} \ No newline at end of file diff --git a/storage/storage.h b/storage/storage.h new file mode 100644 index 0000000..0b15a2d --- /dev/null +++ b/storage/storage.h @@ -0,0 +1,37 @@ +#pragma once +#include +#include +#include "../web/manifest.h" + +typedef struct _STORAGE STORAGE; + +enum +{ + StorageDecompressed = 0x00000001, // Chunks decompressed to solid blocks + StorageCompressed = 0x00000002, // Chunks stay compressed in their downloaded form (this flag isn't used in chunk cache files) + + StorageCompressZlib = 0x00000004, // Chunks are recompressed with Zlib + StorageCompressLZ4 = 0x00000008, // Chunks are recompressed with LZ4 + + StorageCompressFastest = 0x00000010, // Zlib = 1 + StorageCompressFast = 0x00000020, // Zlib = 4 + StorageCompressNormal = 0x00000040, // Zlib = 6 + StorageCompressSlow = 0x00000080, // Zlib = 9 + StorageCompressSlowest = 0x00000100, // Zlib = 12 + + StorageVerifyHashes = 0x00001000, // Verify SHA hashes of downloaded chunks when reading and redownload if invalid +}; + +bool StorageCreate( + uint32_t Flags, + const wchar_t* CacheLocation, + const char* ChunkHost, + const char* CloudDir, + STORAGE** PStorage); +void StorageDelete(STORAGE* Storage); + +bool StorageChunkDownloaded(STORAGE* Storage, MANIFEST_CHUNK* Chunk); +bool StorageVerifyChunk(STORAGE* Storage, MANIFEST_CHUNK* Chunk); +void StorageDownloadChunk(STORAGE* Storage, MANIFEST_CHUNK* Chunk, std::function DataCallback); +void StorageDownloadChunkPart(STORAGE* Storage, MANIFEST_CHUNK* Chunk, uint32_t Offset, uint32_t Size, char* Buffer); +void StorageDownloadChunkPart(STORAGE* Storage, MANIFEST_CHUNK_PART* ChunkPart, char* Buffer); diff --git a/web/http.h b/web/http.h new file mode 100644 index 0000000..f3caf80 --- /dev/null +++ b/web/http.h @@ -0,0 +1,4 @@ +#pragma once + +#include "httplib.h" +#include "url.hh" \ No newline at end of file diff --git a/web/httplib.cc b/web/httplib.cc new file mode 100644 index 0000000..d901514 --- /dev/null +++ b/web/httplib.cc @@ -0,0 +1,4007 @@ +#include "httplib.h" +namespace httplib { + + /* + * Implementation + */ + + namespace detail { + + bool is_hex(char c, int& v) { + if (0x20 <= c && isdigit(c)) { + v = c - '0'; + return true; + } + else if ('A' <= c && c <= 'F') { + v = c - 'A' + 10; + return true; + } + else if ('a' <= c && c <= 'f') { + v = c - 'a' + 10; + return true; + } + return false; + } + + bool from_hex_to_i(const std::string& s, size_t i, size_t cnt, + int& val) { + if (i >= s.size()) { return false; } + + val = 0; + for (; cnt; i++, cnt--) { + if (!s[i]) { return false; } + int v = 0; + if (is_hex(s[i], v)) { + val = val * 16 + v; + } + else { + return false; + } + } + return true; + } + + std::string from_i_to_hex(size_t n) { + const char* charset = "0123456789abcdef"; + std::string ret; + do { + ret = charset[n & 15] + ret; + n >>= 4; + } while (n > 0); + return ret; + } + + size_t to_utf8(int code, char* buff) { + if (code < 0x0080) { + buff[0] = (code & 0x7F); + return 1; + } + else if (code < 0x0800) { + buff[0] = static_cast(0xC0 | ((code >> 6) & 0x1F)); + buff[1] = static_cast(0x80 | (code & 0x3F)); + return 2; + } + else if (code < 0xD800) { + buff[0] = static_cast(0xE0 | ((code >> 12) & 0xF)); + buff[1] = static_cast(0x80 | ((code >> 6) & 0x3F)); + buff[2] = static_cast(0x80 | (code & 0x3F)); + return 3; + } + else if (code < 0xE000) { // D800 - DFFF is invalid... + return 0; + } + else if (code < 0x10000) { + buff[0] = static_cast(0xE0 | ((code >> 12) & 0xF)); + buff[1] = static_cast(0x80 | ((code >> 6) & 0x3F)); + buff[2] = static_cast(0x80 | (code & 0x3F)); + return 3; + } + else if (code < 0x110000) { + buff[0] = static_cast(0xF0 | ((code >> 18) & 0x7)); + buff[1] = static_cast(0x80 | ((code >> 12) & 0x3F)); + buff[2] = static_cast(0x80 | ((code >> 6) & 0x3F)); + buff[3] = static_cast(0x80 | (code & 0x3F)); + return 4; + } + + // NOTREACHED + return 0; + } + + // NOTE: This code came up with the following stackoverflow post: + // https://stackoverflow.com/questions/180947/base64-decode-snippet-in-c + std::string base64_encode(const std::string& in) { + static const auto lookup = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + + std::string out; + out.reserve(in.size()); + + int val = 0; + int valb = -6; + + for (auto c : in) { + val = (val << 8) + static_cast(c); + valb += 8; + while (valb >= 0) { + out.push_back(lookup[(val >> valb) & 0x3F]); + valb -= 6; + } + } + + if (valb > -6) { out.push_back(lookup[((val << 8) >> (valb + 8)) & 0x3F]); } + + while (out.size() % 4) { + out.push_back('='); + } + + return out; + } + + bool is_file(const std::string& path) { + struct stat st; + return stat(path.c_str(), &st) >= 0 && S_ISREG(st.st_mode); + } + + bool is_dir(const std::string& path) { + struct stat st; + return stat(path.c_str(), &st) >= 0 && S_ISDIR(st.st_mode); + } + + bool is_valid_path(const std::string& path) { + size_t level = 0; + size_t i = 0; + + // Skip slash + while (i < path.size() && path[i] == '/') { + i++; + } + + while (i < path.size()) { + // Read component + auto beg = i; + while (i < path.size() && path[i] != '/') { + i++; + } + + auto len = i - beg; + assert(len > 0); + + if (!path.compare(beg, len, ".")) { + ; + } + else if (!path.compare(beg, len, "..")) { + if (level == 0) { return false; } + level--; + } + else { + level++; + } + + // Skip slash + while (i < path.size() && path[i] == '/') { + i++; + } + } + + return true; + } + + void read_file(const std::string& path, std::string& out) { + std::ifstream fs(path, std::ios_base::binary); + fs.seekg(0, std::ios_base::end); + auto size = fs.tellg(); + fs.seekg(0); + out.resize(static_cast(size)); + fs.read(&out[0], size); + } + + std::string file_extension(const std::string& path) { + std::smatch m; + static auto re = std::regex("\\.([a-zA-Z0-9]+)$"); + if (std::regex_search(path, m, re)) { return m[1].str(); } + return std::string(); + } + + template void split(const char* b, const char* e, char d, Fn fn) { + int i = 0; + int beg = 0; + + while (e ? (b + i != e) : (b[i] != '\0')) { + if (b[i] == d) { + fn(&b[beg], &b[i]); + beg = i + 1; + } + i++; + } + + if (i) { fn(&b[beg], &b[i]); } + } + + // NOTE: until the read size reaches `fixed_buffer_size`, use `fixed_buffer` + // to store data. The call can set memory on stack for performance. + class stream_line_reader { + public: + stream_line_reader(Stream& strm, char* fixed_buffer, size_t fixed_buffer_size) + : strm_(strm), fixed_buffer_(fixed_buffer), + fixed_buffer_size_(fixed_buffer_size) {} + + const char* ptr() const { + if (glowable_buffer_.empty()) { + return fixed_buffer_; + } + else { + return glowable_buffer_.data(); + } + } + + size_t size() const { + if (glowable_buffer_.empty()) { + return fixed_buffer_used_size_; + } + else { + return glowable_buffer_.size(); + } + } + + bool end_with_crlf() const { + auto end = ptr() + size(); + return size() >= 2 && end[-2] == '\r' && end[-1] == '\n'; + } + + bool getline() { + fixed_buffer_used_size_ = 0; + glowable_buffer_.clear(); + + for (size_t i = 0;; i++) { + char byte; + auto n = strm_.read(&byte, 1); + + if (n < 0) { + return false; + } + else if (n == 0) { + if (i == 0) { + return false; + } + else { + break; + } + } + + append(byte); + + if (byte == '\n') { break; } + } + + return true; + } + + private: + void append(char c) { + if (fixed_buffer_used_size_ < fixed_buffer_size_ - 1) { + fixed_buffer_[fixed_buffer_used_size_++] = c; + fixed_buffer_[fixed_buffer_used_size_] = '\0'; + } + else { + if (glowable_buffer_.empty()) { + assert(fixed_buffer_[fixed_buffer_used_size_] == '\0'); + glowable_buffer_.assign(fixed_buffer_, fixed_buffer_used_size_); + } + glowable_buffer_ += c; + } + } + + Stream& strm_; + char* fixed_buffer_; + const size_t fixed_buffer_size_; + size_t fixed_buffer_used_size_ = 0; + std::string glowable_buffer_; + }; + + int close_socket(socket_t sock) { +#ifdef _WIN32 + return closesocket(sock); +#else + return close(sock); +#endif + } + + int select_read(socket_t sock, time_t sec, time_t usec) { +#ifdef CPPHTTPLIB_USE_POLL + struct pollfd pfd_read; + pfd_read.fd = sock; + pfd_read.events = POLLIN; + + auto timeout = static_cast(sec * 1000 + usec / 1000); + + return poll(&pfd_read, 1, timeout); +#else + fd_set fds; + FD_ZERO(&fds); + FD_SET(sock, &fds); + + timeval tv; + tv.tv_sec = static_cast(sec); + tv.tv_usec = static_cast(usec); + + return select(static_cast(sock + 1), &fds, nullptr, nullptr, &tv); +#endif + } + + int select_write(socket_t sock, time_t sec, time_t usec) { +#ifdef CPPHTTPLIB_USE_POLL + struct pollfd pfd_read; + pfd_read.fd = sock; + pfd_read.events = POLLOUT; + + auto timeout = static_cast(sec * 1000 + usec / 1000); + + return poll(&pfd_read, 1, timeout); +#else + fd_set fds; + FD_ZERO(&fds); + FD_SET(sock, &fds); + + timeval tv; + tv.tv_sec = static_cast(sec); + tv.tv_usec = static_cast(usec); + + return select(static_cast(sock + 1), nullptr, &fds, nullptr, &tv); +#endif + } + + bool wait_until_socket_is_ready(socket_t sock, time_t sec, time_t usec) { +#ifdef CPPHTTPLIB_USE_POLL + struct pollfd pfd_read; + pfd_read.fd = sock; + pfd_read.events = POLLIN | POLLOUT; + + auto timeout = static_cast(sec * 1000 + usec / 1000); + + if (poll(&pfd_read, 1, timeout) > 0 && + pfd_read.revents & (POLLIN | POLLOUT)) { + int error = 0; + socklen_t len = sizeof(error); + return getsockopt(sock, SOL_SOCKET, SO_ERROR, + reinterpret_cast(&error), &len) >= 0 && + !error; + } + return false; +#else + fd_set fdsr; + FD_ZERO(&fdsr); + FD_SET(sock, &fdsr); + + auto fdsw = fdsr; + auto fdse = fdsr; + + timeval tv; + tv.tv_sec = static_cast(sec); + tv.tv_usec = static_cast(usec); + + if (select(static_cast(sock + 1), &fdsr, &fdsw, &fdse, &tv) > 0 && + (FD_ISSET(sock, &fdsr) || FD_ISSET(sock, &fdsw))) { + int error = 0; + socklen_t len = sizeof(error); + return getsockopt(sock, SOL_SOCKET, SO_ERROR, + reinterpret_cast(&error), &len) >= 0 && + !error; + } + return false; +#endif + } + + class SocketStream : public Stream { + public: + SocketStream(socket_t sock, time_t read_timeout_sec, + time_t read_timeout_usec); + ~SocketStream() override; + + bool is_readable() const override; + bool is_writable() const override; + ssize_t read(char* ptr, size_t size) override; + ssize_t write(const char* ptr, size_t size) override; + std::string get_remote_addr() const override; + + private: + socket_t sock_; + time_t read_timeout_sec_; + time_t read_timeout_usec_; + }; + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + class SSLSocketStream : public Stream { + public: + SSLSocketStream(socket_t sock, SSL* ssl, time_t read_timeout_sec, + time_t read_timeout_usec); + ~SSLSocketStream() override; + + bool is_readable() const override; + bool is_writable() const override; + ssize_t read(char* ptr, size_t size) override; + ssize_t write(const char* ptr, size_t size) override; + std::string get_remote_addr() const override; + + private: + socket_t sock_; + SSL* ssl_; + time_t read_timeout_sec_; + time_t read_timeout_usec_; + }; +#endif + + class BufferStream : public Stream { + public: + BufferStream() = default; + ~BufferStream() override = default; + + bool is_readable() const override; + bool is_writable() const override; + ssize_t read(char* ptr, size_t size) override; + ssize_t write(const char* ptr, size_t size) override; + std::string get_remote_addr() const override; + + const std::string& get_buffer() const; + + private: + std::string buffer; + size_t position = 0; + }; + + template + bool process_socket(bool is_client_request, socket_t sock, + size_t keep_alive_max_count, time_t read_timeout_sec, + time_t read_timeout_usec, T callback) { + assert(keep_alive_max_count > 0); + + auto ret = false; + + if (keep_alive_max_count > 1) { + auto count = keep_alive_max_count; + while (count > 0 && + (is_client_request || + select_read(sock, CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND, + CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND) > 0)) { + SocketStream strm(sock, read_timeout_sec, read_timeout_usec); + auto last_connection = count == 1; + auto connection_close = false; + + ret = callback(strm, last_connection, connection_close); + if (!ret || connection_close) { break; } + + count--; + } + } + else { // keep_alive_max_count is 0 or 1 + SocketStream strm(sock, read_timeout_sec, read_timeout_usec); + auto dummy_connection_close = false; + ret = callback(strm, true, dummy_connection_close); + } + + return ret; + } + + template + bool process_and_close_socket(bool is_client_request, socket_t sock, + size_t keep_alive_max_count, + time_t read_timeout_sec, + time_t read_timeout_usec, T callback) { + auto ret = process_socket(is_client_request, sock, keep_alive_max_count, + read_timeout_sec, read_timeout_usec, callback); + close_socket(sock); + return ret; + } + + int shutdown_socket(socket_t sock) { +#ifdef _WIN32 + return shutdown(sock, SD_BOTH); +#else + return shutdown(sock, SHUT_RDWR); +#endif + } + + template + socket_t create_socket(const char* host, int port, Fn fn, + int socket_flags = 0) { +#ifdef _WIN32 +#define SO_SYNCHRONOUS_NONALERT 0x20 +#define SO_OPENTYPE 0x7008 + + int opt = SO_SYNCHRONOUS_NONALERT; + setsockopt(INVALID_SOCKET, SOL_SOCKET, SO_OPENTYPE, (char*)&opt, + sizeof(opt)); +#endif + + // Get address info + struct addrinfo hints; + struct addrinfo* result; + + memset(&hints, 0, sizeof(struct addrinfo)); + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_flags = socket_flags; + hints.ai_protocol = 0; + + auto service = std::to_string(port); + + if (getaddrinfo(host, service.c_str(), &hints, &result)) { + return INVALID_SOCKET; + } + + for (auto rp = result; rp; rp = rp->ai_next) { + // Create a socket +#ifdef _WIN32 + auto sock = WSASocketW(rp->ai_family, rp->ai_socktype, rp->ai_protocol, + nullptr, 0, WSA_FLAG_NO_HANDLE_INHERIT); + /** + * Since the WSA_FLAG_NO_HANDLE_INHERIT is only supported on Windows 7 SP1 + * and above the socket creation fails on older Windows Systems. + * + * Let's try to create a socket the old way in this case. + * + * Reference: + * https://docs.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-wsasocketa + * + * WSA_FLAG_NO_HANDLE_INHERIT: + * This flag is supported on Windows 7 with SP1, Windows Server 2008 R2 with + * SP1, and later + * + */ + if (sock == INVALID_SOCKET) { + sock = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); + } +#else + auto sock = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); +#endif + if (sock == INVALID_SOCKET) { continue; } + +#ifndef _WIN32 + if (fcntl(sock, F_SETFD, FD_CLOEXEC) == -1) { continue; } +#endif + + // Make 'reuse address' option available + int yes = 1; + setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast(&yes), + sizeof(yes)); +#ifdef SO_REUSEPORT + setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, reinterpret_cast(&yes), + sizeof(yes)); +#endif + + // bind or connect + if (fn(sock, *rp)) { + freeaddrinfo(result); + return sock; + } + + close_socket(sock); + } + + freeaddrinfo(result); + return INVALID_SOCKET; + } + + void set_nonblocking(socket_t sock, bool nonblocking) { +#ifdef _WIN32 + auto flags = nonblocking ? 1UL : 0UL; + ioctlsocket(sock, FIONBIO, &flags); +#else + auto flags = fcntl(sock, F_GETFL, 0); + fcntl(sock, F_SETFL, + nonblocking ? (flags | O_NONBLOCK) : (flags & (~O_NONBLOCK))); +#endif + } + + bool is_connection_error() { +#ifdef _WIN32 + return WSAGetLastError() != WSAEWOULDBLOCK; +#else + return errno != EINPROGRESS; +#endif + } + + bool bind_ip_address(socket_t sock, const char* host) { + struct addrinfo hints; + struct addrinfo* result; + + memset(&hints, 0, sizeof(struct addrinfo)); + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_protocol = 0; + + if (getaddrinfo(host, "0", &hints, &result)) { return false; } + + auto ret = false; + for (auto rp = result; rp; rp = rp->ai_next) { + const auto& ai = *rp; + if (!::bind(sock, ai.ai_addr, static_cast(ai.ai_addrlen))) { + ret = true; + break; + } + } + + freeaddrinfo(result); + return ret; + } + + std::string if2ip(const std::string& ifn) { +#ifndef _WIN32 + struct ifaddrs* ifap; + getifaddrs(&ifap); + for (auto ifa = ifap; ifa; ifa = ifa->ifa_next) { + if (ifa->ifa_addr && ifn == ifa->ifa_name) { + if (ifa->ifa_addr->sa_family == AF_INET) { + auto sa = reinterpret_cast(ifa->ifa_addr); + char buf[INET_ADDRSTRLEN]; + if (inet_ntop(AF_INET, &sa->sin_addr, buf, INET_ADDRSTRLEN)) { + freeifaddrs(ifap); + return std::string(buf, INET_ADDRSTRLEN); + } + } + } + } + freeifaddrs(ifap); +#endif + return std::string(); + } + + socket_t create_client_socket(const char* host, int port, + time_t timeout_sec, + const std::string& intf) { + return create_socket( + host, port, [&](socket_t sock, struct addrinfo& ai) -> bool { + if (!intf.empty()) { + auto ip = if2ip(intf); + if (ip.empty()) { ip = intf; } + if (!bind_ip_address(sock, ip.c_str())) { return false; } + } + + set_nonblocking(sock, true); + + auto ret = + ::connect(sock, ai.ai_addr, static_cast(ai.ai_addrlen)); + if (ret < 0) { + if (is_connection_error() || + !wait_until_socket_is_ready(sock, timeout_sec, 0)) { + close_socket(sock); + return false; + } + } + + set_nonblocking(sock, false); + return true; + }); + } + + std::string get_remote_addr(socket_t sock) { + struct sockaddr_storage addr; + socklen_t len = sizeof(addr); + + if (!getpeername(sock, reinterpret_cast(&addr), &len)) { + std::array ipstr{}; + + if (!getnameinfo(reinterpret_cast(&addr), len, + ipstr.data(), static_cast(ipstr.size()), + nullptr, 0, NI_NUMERICHOST)) { + return ipstr.data(); + } + } + + return std::string(); + } + + const char* + find_content_type(const std::string& path, + const std::map& user_data) { + auto ext = file_extension(path); + + auto it = user_data.find(ext); + if (it != user_data.end()) { return it->second.c_str(); } + + if (ext == "txt") { + return "text/plain"; + } + else if (ext == "html" || ext == "htm") { + return "text/html"; + } + else if (ext == "css") { + return "text/css"; + } + else if (ext == "jpeg" || ext == "jpg") { + return "image/jpg"; + } + else if (ext == "png") { + return "image/png"; + } + else if (ext == "gif") { + return "image/gif"; + } + else if (ext == "svg") { + return "image/svg+xml"; + } + else if (ext == "ico") { + return "image/x-icon"; + } + else if (ext == "json") { + return "application/json"; + } + else if (ext == "pdf") { + return "application/pdf"; + } + else if (ext == "js") { + return "application/javascript"; + } + else if (ext == "wasm") { + return "application/wasm"; + } + else if (ext == "xml") { + return "application/xml"; + } + else if (ext == "xhtml") { + return "application/xhtml+xml"; + } + return nullptr; + } + + const char* status_message(int status) { + switch (status) { + case 100: return "Continue"; + case 200: return "OK"; + case 202: return "Accepted"; + case 204: return "No Content"; + case 206: return "Partial Content"; + case 301: return "Moved Permanently"; + case 302: return "Found"; + case 303: return "See Other"; + case 304: return "Not Modified"; + case 400: return "Bad Request"; + case 401: return "Unauthorized"; + case 403: return "Forbidden"; + case 404: return "Not Found"; + case 413: return "Payload Too Large"; + case 414: return "Request-URI Too Long"; + case 415: return "Unsupported Media Type"; + case 416: return "Range Not Satisfiable"; + case 417: return "Expectation Failed"; + case 503: return "Service Unavailable"; + + default: + case 500: return "Internal Server Error"; + } + } + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + bool can_compress(const std::string& content_type) { + return !content_type.find("text/") || content_type == "image/svg+xml" || + content_type == "application/javascript" || + content_type == "application/json" || + content_type == "application/xml" || + content_type == "application/xhtml+xml"; + } + + bool compress(std::string& content) { + z_stream strm; + strm.zalloc = Z_NULL; + strm.zfree = Z_NULL; + strm.opaque = Z_NULL; + + auto ret = deflateInit2(&strm, Z_DEFAULT_COMPRESSION, Z_DEFLATED, 31, 8, + Z_DEFAULT_STRATEGY); + if (ret != Z_OK) { return false; } + + strm.avail_in = static_cast(content.size()); + strm.next_in = + const_cast(reinterpret_cast(content.data())); + + std::string compressed; + + std::array buff{}; + do { + strm.avail_out = buff.size(); + strm.next_out = reinterpret_cast(buff.data()); + ret = deflate(&strm, Z_FINISH); + assert(ret != Z_STREAM_ERROR); + compressed.append(buff.data(), buff.size() - strm.avail_out); + } while (strm.avail_out == 0); + + assert(ret == Z_STREAM_END); + assert(strm.avail_in == 0); + + content.swap(compressed); + + deflateEnd(&strm); + return true; + } + + class decompressor { + public: + decompressor() { + std::memset(&strm, 0, sizeof(strm)); + strm.zalloc = Z_NULL; + strm.zfree = Z_NULL; + strm.opaque = Z_NULL; + + // 15 is the value of wbits, which should be at the maximum possible value + // to ensure that any gzip stream can be decoded. The offset of 32 specifies + // that the stream type should be automatically detected either gzip or + // deflate. + is_valid_ = inflateInit2(&strm, 32 + 15) == Z_OK; + } + + ~decompressor() { inflateEnd(&strm); } + + bool is_valid() const { return is_valid_; } + + template + bool decompress(const char* data, size_t data_length, T callback) { + int ret = Z_OK; + + strm.avail_in = static_cast(data_length); + strm.next_in = const_cast(reinterpret_cast(data)); + + std::array buff{}; + do { + strm.avail_out = buff.size(); + strm.next_out = reinterpret_cast(buff.data()); + + ret = inflate(&strm, Z_NO_FLUSH); + assert(ret != Z_STREAM_ERROR); + switch (ret) { + case Z_NEED_DICT: + case Z_DATA_ERROR: + case Z_MEM_ERROR: inflateEnd(&strm); return false; + } + + if (!callback(buff.data(), buff.size() - strm.avail_out)) { + return false; + } + } while (strm.avail_out == 0); + + return ret == Z_OK || ret == Z_STREAM_END; + } + + private: + bool is_valid_; + z_stream strm; + }; +#endif + + bool has_header(const Headers& headers, const char* key) { + return headers.find(key) != headers.end(); + } + + const char* get_header_value(const Headers& headers, const char* key, + size_t id = 0, const char* def = nullptr) { + auto it = headers.find(key); + std::advance(it, static_cast(id)); + if (it != headers.end()) { return it->second.c_str(); } + return def; + } + + uint64_t get_header_value_uint64(const Headers& headers, const char* key, + uint64_t def = 0) { + auto it = headers.find(key); + if (it != headers.end()) { + return std::strtoull(it->second.data(), nullptr, 10); + } + return def; + } + + bool read_headers(Stream& strm, Headers& headers) { + const auto bufsiz = 2048; + char buf[bufsiz]; + stream_line_reader line_reader(strm, buf, bufsiz); + + for (;;) { + if (!line_reader.getline()) { return false; } + + // Check if the line ends with CRLF. + if (line_reader.end_with_crlf()) { + // Blank line indicates end of headers. + if (line_reader.size() == 2) { break; } + } + else { + continue; // Skip invalid line. + } + + // Skip trailing spaces and tabs. + auto end = line_reader.ptr() + line_reader.size() - 2; + while (line_reader.ptr() < end && (end[-1] == ' ' || end[-1] == '\t')) { + end--; + } + + // Horizontal tab and ' ' are considered whitespace and are ignored when on + // the left or right side of the header value: + // - https://stackoverflow.com/questions/50179659/ + // - https://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html + static const std::regex re(R"(([^:]+):[\t ]*(.+))"); + + std::cmatch m; + if (std::regex_match(line_reader.ptr(), end, m, re)) { + auto key = std::string(m[1]); + auto val = std::string(m[2]); + headers.emplace(key, val); + } + } + + return true; + } + + bool read_content_with_length(Stream& strm, uint64_t len, + Progress progress, ContentReceiver out) { + char buf[CPPHTTPLIB_RECV_BUFSIZ]; + + uint64_t r = 0; + while (r < len) { + auto read_len = static_cast(len - r); + auto n = strm.read(buf, std::min(read_len, CPPHTTPLIB_RECV_BUFSIZ)); + if (n <= 0) { return false; } + + if (!out(buf, static_cast(n))) { return false; } + + r += static_cast(n); + + if (progress) { + if (!progress(r, len)) { return false; } + } + } + + return true; + } + + void skip_content_with_length(Stream& strm, uint64_t len) { + char buf[CPPHTTPLIB_RECV_BUFSIZ]; + uint64_t r = 0; + while (r < len) { + auto read_len = static_cast(len - r); + auto n = strm.read(buf, std::min(read_len, CPPHTTPLIB_RECV_BUFSIZ)); + if (n <= 0) { return; } + r += static_cast(n); + } + } + + bool read_content_without_length(Stream& strm, ContentReceiver out) { + char buf[CPPHTTPLIB_RECV_BUFSIZ]; + for (;;) { + auto n = strm.read(buf, CPPHTTPLIB_RECV_BUFSIZ); + if (n < 0) { + return false; + } + else if (n == 0) { + return true; + } + if (!out(buf, static_cast(n))) { return false; } + } + + return true; + } + + bool read_content_chunked(Stream& strm, ContentReceiver out) { + const auto bufsiz = 16; + char buf[bufsiz]; + + stream_line_reader line_reader(strm, buf, bufsiz); + + if (!line_reader.getline()) { return false; } + + auto chunk_len = std::stoul(line_reader.ptr(), 0, 16); + + while (chunk_len > 0) { + if (!read_content_with_length(strm, chunk_len, nullptr, out)) { + return false; + } + + if (!line_reader.getline()) { return false; } + + if (strcmp(line_reader.ptr(), "\r\n")) { break; } + + if (!line_reader.getline()) { return false; } + + chunk_len = std::stoul(line_reader.ptr(), 0, 16); + } + + if (chunk_len == 0) { + // Reader terminator after chunks + if (!line_reader.getline() || strcmp(line_reader.ptr(), "\r\n")) + return false; + } + + return true; + } + + bool is_chunked_transfer_encoding(const Headers& headers) { + return !strcasecmp(get_header_value(headers, "Transfer-Encoding", 0, ""), + "chunked"); + } + + template + bool read_content(Stream& strm, T& x, size_t payload_max_length, int& status, + Progress progress, ContentReceiver receiver) { + + ContentReceiver out = [&](const char* buf, size_t n) { + return receiver(buf, n); + }; + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + decompressor decompressor; + + std::string content_encoding = x.get_header_value("Content-Encoding"); + if (content_encoding.find("gzip") != std::string::npos || + content_encoding.find("deflate") != std::string::npos) { + if (!decompressor.is_valid()) { + status = 500; + return false; + } + + out = [&](const char* buf, size_t n) { + return decompressor.decompress( + buf, n, [&](const char* buf, size_t n) { return receiver(buf, n); }); + }; + } +#else + if (x.get_header_value("Content-Encoding") == "gzip") { + status = 415; + return false; + } +#endif + + auto ret = true; + auto exceed_payload_max_length = false; + + if (is_chunked_transfer_encoding(x.headers)) { + ret = read_content_chunked(strm, out); + } + else if (!has_header(x.headers, "Content-Length")) { + ret = read_content_without_length(strm, out); + } + else { + auto len = get_header_value_uint64(x.headers, "Content-Length", 0); + if (len > payload_max_length) { + exceed_payload_max_length = true; + skip_content_with_length(strm, len); + ret = false; + } + else if (len > 0) { + ret = read_content_with_length(strm, len, progress, out); + } + } + + if (!ret) { status = exceed_payload_max_length ? 413 : 400; } + + return ret; + } + + template + ssize_t write_headers(Stream& strm, const T& info, + const Headers& headers) { + ssize_t write_len = 0; + for (const auto& x : info.headers) { + if (x.first == "EXCEPTION_WHAT") { continue; } + auto len = + strm.write_format("%s: %s\r\n", x.first.c_str(), x.second.c_str()); + if (len < 0) { return len; } + write_len += len; + } + for (const auto& x : headers) { + auto len = + strm.write_format("%s: %s\r\n", x.first.c_str(), x.second.c_str()); + if (len < 0) { return len; } + write_len += len; + } + auto len = strm.write("\r\n"); + if (len < 0) { return len; } + write_len += len; + return write_len; + } + + ssize_t write_content(Stream& strm, ContentProvider content_provider, + size_t offset, size_t length) { + size_t begin_offset = offset; + size_t end_offset = offset + length; + while (offset < end_offset) { + ssize_t written_length = 0; + + DataSink data_sink; + data_sink.write = [&](const char* d, size_t l) { + offset += l; + written_length = strm.write(d, l); + }; + data_sink.done = [&](void) { written_length = -1; }; + data_sink.is_writable = [&](void) { return strm.is_writable(); }; + + content_provider(offset, end_offset - offset, data_sink); + if (written_length < 0) { return written_length; } + } + return static_cast(offset - begin_offset); + } + + template + ssize_t write_content_chunked(Stream& strm, + ContentProvider content_provider, + T is_shutting_down) { + size_t offset = 0; + auto data_available = true; + ssize_t total_written_length = 0; + while (data_available && !is_shutting_down()) { + ssize_t written_length = 0; + + DataSink data_sink; + data_sink.write = [&](const char* d, size_t l) { + data_available = l > 0; + offset += l; + + // Emit chunked response header and footer for each chunk + auto chunk = from_i_to_hex(l) + "\r\n" + std::string(d, l) + "\r\n"; + written_length = strm.write(chunk); + }; + data_sink.done = [&](void) { + data_available = false; + written_length = strm.write("0\r\n\r\n"); + }; + data_sink.is_writable = [&](void) { return strm.is_writable(); }; + + content_provider(offset, 0, data_sink); + + if (written_length < 0) { return written_length; } + total_written_length += written_length; + } + return total_written_length; + } + + template + bool redirect(T& cli, const Request& req, Response& res, + const std::string& path) { + Request new_req = req; + new_req.path = path; + new_req.redirect_count -= 1; + + Response new_res; + + auto ret = cli.send(new_req, new_res); + if (ret) { res = new_res; } + return ret; + } + + std::string encode_url(const std::string& s) { + std::string result; + + for (size_t i = 0; s[i]; i++) { + switch (s[i]) { + case ' ': result += "%20"; break; + case '+': result += "%2B"; break; + case '\r': result += "%0D"; break; + case '\n': result += "%0A"; break; + case '\'': result += "%27"; break; + case ',': result += "%2C"; break; + // case ':': result += "%3A"; break; // ok? probably... + case ';': result += "%3B"; break; + default: + auto c = static_cast(s[i]); + if (c >= 0x80) { + result += '%'; + char hex[4]; + auto len = snprintf(hex, sizeof(hex) - 1, "%02X", c); + assert(len == 2); + result.append(hex, static_cast(len)); + } + else { + result += s[i]; + } + break; + } + } + + return result; + } + + std::string decode_url(const std::string& s, + bool convert_plus_to_space) { + std::string result; + + for (size_t i = 0; i < s.size(); i++) { + if (s[i] == '%' && i + 1 < s.size()) { + if (s[i + 1] == 'u') { + int val = 0; + if (from_hex_to_i(s, i + 2, 4, val)) { + // 4 digits Unicode codes + char buff[4]; + size_t len = to_utf8(val, buff); + if (len > 0) { result.append(buff, len); } + i += 5; // 'u0000' + } + else { + result += s[i]; + } + } + else { + int val = 0; + if (from_hex_to_i(s, i + 1, 2, val)) { + // 2 digits hex codes + result += static_cast(val); + i += 2; // '00' + } + else { + result += s[i]; + } + } + } + else if (convert_plus_to_space && s[i] == '+') { + result += ' '; + } + else { + result += s[i]; + } + } + + return result; + } + + std::string params_to_query_str(const Params& params) { + std::string query; + + for (auto it = params.begin(); it != params.end(); ++it) { + if (it != params.begin()) { query += "&"; } + query += it->first; + query += "="; + query += detail::encode_url(it->second); + } + + return query; + } + + void parse_query_text(const std::string& s, Params& params) { + split(&s[0], &s[s.size()], '&', [&](const char* b, const char* e) { + std::string key; + std::string val; + split(b, e, '=', [&](const char* b2, const char* e2) { + if (key.empty()) { + key.assign(b2, e2); + } + else { + val.assign(b2, e2); + } + }); + params.emplace(decode_url(key, true), decode_url(val, true)); + }); + } + + bool parse_multipart_boundary(const std::string& content_type, + std::string& boundary) { + auto pos = content_type.find("boundary="); + if (pos == std::string::npos) { return false; } + + boundary = content_type.substr(pos + 9); + return true; + } + + bool parse_range_header(const std::string& s, Ranges& ranges) { + static auto re_first_range = std::regex(R"(bytes=(\d*-\d*(?:,\s*\d*-\d*)*))"); + std::smatch m; + if (std::regex_match(s, m, re_first_range)) { + auto pos = static_cast(m.position(1)); + auto len = static_cast(m.length(1)); + bool all_valid_ranges = true; + split(&s[pos], &s[pos + len], ',', [&](const char* b, const char* e) { + if (!all_valid_ranges) return; + static auto re_another_range = std::regex(R"(\s*(\d*)-(\d*))"); + std::cmatch cm; + if (std::regex_match(b, e, cm, re_another_range)) { + ssize_t first = -1; + if (!cm.str(1).empty()) { + first = static_cast(std::stoll(cm.str(1))); + } + + ssize_t last = -1; + if (!cm.str(2).empty()) { + last = static_cast(std::stoll(cm.str(2))); + } + + if (first != -1 && last != -1 && first > last) { + all_valid_ranges = false; + return; + } + ranges.emplace_back(std::make_pair(first, last)); + } + }); + return all_valid_ranges; + } + return false; + } + + class MultipartFormDataParser { + public: + MultipartFormDataParser() {} + + void set_boundary(const std::string& boundary) { boundary_ = boundary; } + + bool is_valid() const { return is_valid_; } + + template + bool parse(const char* buf, size_t n, T content_callback, U header_callback) { + static const std::regex re_content_type(R"(^Content-Type:\s*(.*?)\s*$)", + std::regex_constants::icase); + + static const std::regex re_content_disposition( + "^Content-Disposition:\\s*form-data;\\s*name=\"(.*?)\"(?:;\\s*filename=" + "\"(.*?)\")?\\s*$", + std::regex_constants::icase); + + buf_.append(buf, n); // TODO: performance improvement + + while (!buf_.empty()) { + switch (state_) { + case 0: { // Initial boundary + auto pattern = dash_ + boundary_ + crlf_; + if (pattern.size() > buf_.size()) { return true; } + auto pos = buf_.find(pattern); + if (pos != 0) { + is_done_ = true; + return false; + } + buf_.erase(0, pattern.size()); + off_ += pattern.size(); + state_ = 1; + break; + } + case 1: { // New entry + clear_file_info(); + state_ = 2; + break; + } + case 2: { // Headers + auto pos = buf_.find(crlf_); + while (pos != std::string::npos) { + // Empty line + if (pos == 0) { + if (!header_callback(file_)) { + is_valid_ = false; + is_done_ = false; + return false; + } + buf_.erase(0, crlf_.size()); + off_ += crlf_.size(); + state_ = 3; + break; + } + + auto header = buf_.substr(0, pos); + { + std::smatch m; + if (std::regex_match(header, m, re_content_type)) { + file_.content_type = m[1]; + } + else if (std::regex_match(header, m, re_content_disposition)) { + file_.name = m[1]; + file_.filename = m[2]; + } + } + + buf_.erase(0, pos + crlf_.size()); + off_ += pos + crlf_.size(); + pos = buf_.find(crlf_); + } + break; + } + case 3: { // Body + { + auto pattern = crlf_ + dash_; + if (pattern.size() > buf_.size()) { return true; } + + auto pos = buf_.find(pattern); + if (pos == std::string::npos) { pos = buf_.size(); } + if (!content_callback(buf_.data(), pos)) { + is_valid_ = false; + is_done_ = false; + return false; + } + + off_ += pos; + buf_.erase(0, pos); + } + + { + auto pattern = crlf_ + dash_ + boundary_; + if (pattern.size() > buf_.size()) { return true; } + + auto pos = buf_.find(pattern); + if (pos != std::string::npos) { + if (!content_callback(buf_.data(), pos)) { + is_valid_ = false; + is_done_ = false; + return false; + } + + off_ += pos + pattern.size(); + buf_.erase(0, pos + pattern.size()); + state_ = 4; + } + else { + if (!content_callback(buf_.data(), pattern.size())) { + is_valid_ = false; + is_done_ = false; + return false; + } + + off_ += pattern.size(); + buf_.erase(0, pattern.size()); + } + } + break; + } + case 4: { // Boundary + if (crlf_.size() > buf_.size()) { return true; } + if (buf_.find(crlf_) == 0) { + buf_.erase(0, crlf_.size()); + off_ += crlf_.size(); + state_ = 1; + } + else { + auto pattern = dash_ + crlf_; + if (pattern.size() > buf_.size()) { return true; } + if (buf_.find(pattern) == 0) { + buf_.erase(0, pattern.size()); + off_ += pattern.size(); + is_valid_ = true; + state_ = 5; + } + else { + is_done_ = true; + return true; + } + } + break; + } + case 5: { // Done + is_valid_ = false; + return false; + } + } + } + + return true; + } + + private: + void clear_file_info() { + file_.name.clear(); + file_.filename.clear(); + file_.content_type.clear(); + } + + const std::string dash_ = "--"; + const std::string crlf_ = "\r\n"; + std::string boundary_; + + std::string buf_; + size_t state_ = 0; + size_t is_valid_ = false; + size_t is_done_ = false; + size_t off_ = 0; + MultipartFormData file_; + }; + + std::string to_lower(const char* beg, const char* end) { + std::string out; + auto it = beg; + while (it != end) { + out += static_cast(::tolower(*it)); + it++; + } + return out; + } + + std::string make_multipart_data_boundary() { + static const char data[] = + "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"; + + std::random_device seed_gen; + std::mt19937 engine(seed_gen()); + + std::string result = "--cpp-httplib-multipart-data-"; + + for (auto i = 0; i < 16; i++) { + result += data[engine() % (sizeof(data) - 1)]; + } + + return result; + } + + std::pair + get_range_offset_and_length(const Request& req, size_t content_length, + size_t index) { + auto r = req.ranges[index]; + + if (r.first == -1 && r.second == -1) { + return std::make_pair(0, content_length); + } + + auto slen = static_cast(content_length); + + if (r.first == -1) { + r.first = slen - r.second; + r.second = slen - 1; + } + + if (r.second == -1) { r.second = slen - 1; } + + return std::make_pair(r.first, r.second - r.first + 1); + } + + std::string make_content_range_header_field(size_t offset, size_t length, + size_t content_length) { + std::string field = "bytes "; + field += std::to_string(offset); + field += "-"; + field += std::to_string(offset + length - 1); + field += "/"; + field += std::to_string(content_length); + return field; + } + + template + bool process_multipart_ranges_data(const Request& req, Response& res, + const std::string& boundary, + const std::string& content_type, + SToken stoken, CToken ctoken, + Content content) { + for (size_t i = 0; i < req.ranges.size(); i++) { + ctoken("--"); + stoken(boundary); + ctoken("\r\n"); + if (!content_type.empty()) { + ctoken("Content-Type: "); + stoken(content_type); + ctoken("\r\n"); + } + + auto offsets = get_range_offset_and_length(req, res.body.size(), i); + auto offset = offsets.first; + auto length = offsets.second; + + ctoken("Content-Range: "); + stoken(make_content_range_header_field(offset, length, res.body.size())); + ctoken("\r\n"); + ctoken("\r\n"); + if (!content(offset, length)) { return false; } + ctoken("\r\n"); + } + + ctoken("--"); + stoken(boundary); + ctoken("--\r\n"); + + return true; + } + + std::string make_multipart_ranges_data(const Request& req, Response& res, + const std::string& boundary, + const std::string& content_type) { + std::string data; + + process_multipart_ranges_data( + req, res, boundary, content_type, + [&](const std::string& token) { data += token; }, + [&](const char* token) { data += token; }, + [&](size_t offset, size_t length) { + data += res.body.substr(offset, length); + return true; + }); + + return data; + } + + size_t + get_multipart_ranges_data_length(const Request& req, Response& res, + const std::string& boundary, + const std::string& content_type) { + size_t data_length = 0; + + process_multipart_ranges_data( + req, res, boundary, content_type, + [&](const std::string& token) { data_length += token.size(); }, + [&](const char* token) { data_length += strlen(token); }, + [&](size_t /*offset*/, size_t length) { + data_length += length; + return true; + }); + + return data_length; + } + + bool write_multipart_ranges_data(Stream& strm, const Request& req, + Response& res, + const std::string& boundary, + const std::string& content_type) { + return process_multipart_ranges_data( + req, res, boundary, content_type, + [&](const std::string& token) { strm.write(token); }, + [&](const char* token) { strm.write(token); }, + [&](size_t offset, size_t length) { + return write_content(strm, res.content_provider, offset, length) >= 0; + }); + } + + std::pair + get_range_offset_and_length(const Request& req, const Response& res, + size_t index) { + auto r = req.ranges[index]; + + if (r.second == -1) { + r.second = static_cast(res.content_length) - 1; + } + + return std::make_pair(r.first, r.second - r.first + 1); + } + + bool expect_content(const Request& req) { + if (req.method == "POST" || req.method == "PUT" || req.method == "PATCH" || + req.method == "PRI") { + return true; + } + // TODO: check if Content-Length is set + return false; + } + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + template + std::string message_digest(const std::string& s, Init init, + Update update, Final final, + size_t digest_length) { + using namespace std; + + std::vector md(digest_length, 0); + CTX ctx; + init(&ctx); + update(&ctx, s.data(), s.size()); + final(md.data(), &ctx); + + stringstream ss; + for (auto c : md) { + ss << setfill('0') << setw(2) << hex << (unsigned int)c; + } + return ss.str(); + } + + std::string MD5(const std::string& s) { + return message_digest(s, MD5_Init, MD5_Update, MD5_Final, + MD5_DIGEST_LENGTH); + } + + std::string SHA_256(const std::string& s) { + return message_digest(s, SHA256_Init, SHA256_Update, SHA256_Final, + SHA256_DIGEST_LENGTH); + } + + std::string SHA_512(const std::string& s) { + return message_digest(s, SHA512_Init, SHA512_Update, SHA512_Final, + SHA512_DIGEST_LENGTH); + } +#endif + +#ifdef _WIN32 + class WSInit { + public: + WSInit() { + WSADATA wsaData; + WSAStartup(0x0002, &wsaData); + } + + ~WSInit() { WSACleanup(); } + }; + + static WSInit wsinit_; +#endif + + } // namespace detail + + // Header utilities + std::pair make_range_header(Ranges ranges) { + std::string field = "bytes="; + auto i = 0; + for (auto r : ranges) { + if (i != 0) { field += ", "; } + if (r.first != -1) { field += std::to_string(r.first); } + field += '-'; + if (r.second != -1) { field += std::to_string(r.second); } + i++; + } + return std::make_pair("Range", field); + } + + std::pair + make_basic_authentication_header(const std::string& username, + const std::string& password, + bool is_proxy = false) { + auto field = "Basic " + detail::base64_encode(username + ":" + password); + auto key = is_proxy ? "Proxy-Authorization" : "Authorization"; + return std::make_pair(key, field); + } + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + std::pair make_digest_authentication_header( + const Request& req, const std::map& auth, + size_t cnonce_count, const std::string& cnonce, const std::string& username, + const std::string& password, bool is_proxy = false) { + using namespace std; + + string nc; + { + stringstream ss; + ss << setfill('0') << setw(8) << hex << cnonce_count; + nc = ss.str(); + } + + auto qop = auth.at("qop"); + if (qop.find("auth-int") != std::string::npos) { + qop = "auth-int"; + } + else { + qop = "auth"; + } + + std::string algo = "MD5"; + if (auth.find("algorithm") != auth.end()) { algo = auth.at("algorithm"); } + + string response; + { + auto H = algo == "SHA-256" + ? detail::SHA_256 + : algo == "SHA-512" ? detail::SHA_512 : detail::MD5; + + auto A1 = username + ":" + auth.at("realm") + ":" + password; + + auto A2 = req.method + ":" + req.path; + if (qop == "auth-int") { A2 += ":" + H(req.body); } + + response = H(H(A1) + ":" + auth.at("nonce") + ":" + nc + ":" + cnonce + + ":" + qop + ":" + H(A2)); + } + + auto field = "Digest username=\"hello\", realm=\"" + auth.at("realm") + + "\", nonce=\"" + auth.at("nonce") + "\", uri=\"" + req.path + + "\", algorithm=" + algo + ", qop=" + qop + ", nc=\"" + nc + + "\", cnonce=\"" + cnonce + "\", response=\"" + response + "\""; + + auto key = is_proxy ? "Proxy-Authorization" : "Authorization"; + return std::make_pair(key, field); + } +#endif + + bool parse_www_authenticate(const httplib::Response& res, + std::map& auth, + bool is_proxy) { + auto auth_key = is_proxy ? "Proxy-Authenticate" : "WWW-Authenticate"; + if (res.has_header(auth_key)) { + static auto re = std::regex(R"~((?:(?:,\s*)?(.+?)=(?:"(.*?)"|([^,]*))))~"); + auto s = res.get_header_value(auth_key); + auto pos = s.find(' '); + if (pos != std::string::npos) { + auto type = s.substr(0, pos); + if (type == "Basic") { + return false; + } + else if (type == "Digest") { + s = s.substr(pos + 1); + auto beg = std::sregex_iterator(s.begin(), s.end(), re); + for (auto i = beg; i != std::sregex_iterator(); ++i) { + auto m = *i; + auto key = s.substr(static_cast(m.position(1)), + static_cast(m.length(1))); + auto val = m.length(2) > 0 + ? s.substr(static_cast(m.position(2)), + static_cast(m.length(2))) + : s.substr(static_cast(m.position(3)), + static_cast(m.length(3))); + auth[key] = val; + } + return true; + } + } + } + return false; + } + + // https://stackoverflow.com/questions/440133/how-do-i-create-a-random-alpha-numeric-string-in-c/440240#answer-440240 + std::string random_string(size_t length) { + auto randchar = []() -> char { + const char charset[] = "0123456789" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghijklmnopqrstuvwxyz"; + const size_t max_index = (sizeof(charset) - 1); + return charset[static_cast(rand()) % max_index]; + }; + std::string str(length, 0); + std::generate_n(str.begin(), length, randchar); + return str; + } + + // Request implementation + bool Request::has_header(const char* key) const { + return detail::has_header(headers, key); + } + + std::string Request::get_header_value(const char* key, size_t id) const { + return detail::get_header_value(headers, key, id, ""); + } + + size_t Request::get_header_value_count(const char* key) const { + auto r = headers.equal_range(key); + return static_cast(std::distance(r.first, r.second)); + } + + void Request::set_header(const char* key, const char* val) { + headers.emplace(key, val); + } + + void Request::set_header(const char* key, const std::string& val) { + headers.emplace(key, val); + } + + bool Request::has_param(const char* key) const { + return params.find(key) != params.end(); + } + + std::string Request::get_param_value(const char* key, size_t id) const { + auto it = params.find(key); + std::advance(it, static_cast(id)); + if (it != params.end()) { return it->second; } + return std::string(); + } + + size_t Request::get_param_value_count(const char* key) const { + auto r = params.equal_range(key); + return static_cast(std::distance(r.first, r.second)); + } + + bool Request::is_multipart_form_data() const { + const auto& content_type = get_header_value("Content-Type"); + return !content_type.find("multipart/form-data"); + } + + bool Request::has_file(const char* key) const { + return files.find(key) != files.end(); + } + + MultipartFormData Request::get_file_value(const char* key) const { + auto it = files.find(key); + if (it != files.end()) { return it->second; } + return MultipartFormData(); + } + + // Response implementation + bool Response::has_header(const char* key) const { + return headers.find(key) != headers.end(); + } + + std::string Response::get_header_value(const char* key, + size_t id) const { + return detail::get_header_value(headers, key, id, ""); + } + + size_t Response::get_header_value_count(const char* key) const { + auto r = headers.equal_range(key); + return static_cast(std::distance(r.first, r.second)); + } + + void Response::set_header(const char* key, const char* val) { + headers.emplace(key, val); + } + + void Response::set_header(const char* key, const std::string& val) { + headers.emplace(key, val); + } + + void Response::set_redirect(const char* url) { + set_header("Location", url); + status = 302; + } + + void Response::set_content(const char* s, size_t n, + const char* content_type) { + body.assign(s, n); + set_header("Content-Type", content_type); + } + + void Response::set_content(const std::string& s, + const char* content_type) { + body = s; + set_header("Content-Type", content_type); + } + + void Response::set_content_provider( + size_t in_length, + std::function provider, + std::function resource_releaser) { + assert(in_length > 0); + content_length = in_length; + content_provider = [provider](size_t offset, size_t length, DataSink& sink) { + provider(offset, length, sink); + }; + content_provider_resource_releaser = resource_releaser; + } + + void Response::set_chunked_content_provider( + std::function provider, + std::function resource_releaser) { + content_length = 0; + content_provider = [provider](size_t offset, size_t, DataSink& sink) { + provider(offset, sink); + }; + content_provider_resource_releaser = resource_releaser; + } + + // Rstream implementation + ssize_t Stream::write(const char* ptr) { + return write(ptr, strlen(ptr)); + } + + ssize_t Stream::write(const std::string& s) { + return write(s.data(), s.size()); + } + + template + ssize_t Stream::write_format(const char* fmt, const Args&... args) { + std::array buf; + +#if defined(_MSC_VER) && _MSC_VER < 1900 + auto sn = _snprintf_s(buf, bufsiz, buf.size() - 1, fmt, args...); +#else + auto sn = snprintf(buf.data(), buf.size() - 1, fmt, args...); +#endif + if (sn <= 0) { return sn; } + + auto n = static_cast(sn); + + if (n >= buf.size() - 1) { + std::vector glowable_buf(buf.size()); + + while (n >= glowable_buf.size() - 1) { + glowable_buf.resize(glowable_buf.size() * 2); +#if defined(_MSC_VER) && _MSC_VER < 1900 + n = static_cast(_snprintf_s(&glowable_buf[0], glowable_buf.size(), + glowable_buf.size() - 1, fmt, + args...)); +#else + n = static_cast( + snprintf(&glowable_buf[0], glowable_buf.size() - 1, fmt, args...)); +#endif + } + return write(&glowable_buf[0], n); + } + else { + return write(buf.data(), n); + } + } + + namespace detail { + + // Socket stream implementation + SocketStream::SocketStream(socket_t sock, time_t read_timeout_sec, + time_t read_timeout_usec) + : sock_(sock), read_timeout_sec_(read_timeout_sec), + read_timeout_usec_(read_timeout_usec) {} + + SocketStream::~SocketStream() {} + + bool SocketStream::is_readable() const { + return detail::select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0; + } + + bool SocketStream::is_writable() const { + return detail::select_write(sock_, 0, 0) > 0; + } + + ssize_t SocketStream::read(char* ptr, size_t size) { + if (is_readable()) { return recv(sock_, ptr, size, 0); } + return -1; + } + + ssize_t SocketStream::write(const char* ptr, size_t size) { + if (is_writable()) { return send(sock_, ptr, size, 0); } + return -1; + } + + std::string SocketStream::get_remote_addr() const { + return detail::get_remote_addr(sock_); + } + + // Buffer stream implementation + bool BufferStream::is_readable() const { return true; } + + bool BufferStream::is_writable() const { return true; } + + ssize_t BufferStream::read(char* ptr, size_t size) { +#if defined(_MSC_VER) && _MSC_VER < 1900 + auto len_read = buffer._Copy_s(ptr, size, size, position); +#else + auto len_read = buffer.copy(ptr, size, position); +#endif + position += static_cast(len_read); + return static_cast(len_read); + } + + ssize_t BufferStream::write(const char* ptr, size_t size) { + buffer.append(ptr, size); + return static_cast(size); + } + + std::string BufferStream::get_remote_addr() const { return ""; } + + const std::string& BufferStream::get_buffer() const { return buffer; } + + } // namespace detail + + // HTTP server implementation + Server::Server() + : keep_alive_max_count_(CPPHTTPLIB_KEEPALIVE_MAX_COUNT), + read_timeout_sec_(CPPHTTPLIB_READ_TIMEOUT_SECOND), + read_timeout_usec_(CPPHTTPLIB_READ_TIMEOUT_USECOND), + payload_max_length_(CPPHTTPLIB_PAYLOAD_MAX_LENGTH), is_running_(false), + svr_sock_(INVALID_SOCKET) { +#ifndef _WIN32 + signal(SIGPIPE, SIG_IGN); +#endif + new_task_queue = [] { return new ThreadPool(CPPHTTPLIB_THREAD_POOL_COUNT); }; + } + + Server::~Server() {} + + Server& Server::Get(const char* pattern, Handler handler) { + get_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); + return *this; + } + + Server& Server::Post(const char* pattern, Handler handler) { + post_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); + return *this; + } + + Server& Server::Post(const char* pattern, + HandlerWithContentReader handler) { + post_handlers_for_content_reader_.push_back( + std::make_pair(std::regex(pattern), handler)); + return *this; + } + + Server& Server::Put(const char* pattern, Handler handler) { + put_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); + return *this; + } + + Server& Server::Put(const char* pattern, + HandlerWithContentReader handler) { + put_handlers_for_content_reader_.push_back( + std::make_pair(std::regex(pattern), handler)); + return *this; + } + + Server& Server::Patch(const char* pattern, Handler handler) { + patch_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); + return *this; + } + + Server& Server::Patch(const char* pattern, + HandlerWithContentReader handler) { + patch_handlers_for_content_reader_.push_back( + std::make_pair(std::regex(pattern), handler)); + return *this; + } + + Server& Server::Delete(const char* pattern, Handler handler) { + delete_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); + return *this; + } + + Server& Server::Options(const char* pattern, Handler handler) { + options_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); + return *this; + } + + bool Server::set_base_dir(const char* dir, const char* mount_point) { + return set_mount_point(mount_point, dir); + } + + bool Server::set_mount_point(const char* mount_point, const char* dir) { + if (detail::is_dir(dir)) { + std::string mnt = mount_point ? mount_point : "/"; + if (!mnt.empty() && mnt[0] == '/') { + base_dirs_.emplace_back(mnt, dir); + return true; + } + } + return false; + } + + bool Server::remove_mount_point(const char* mount_point) { + for (auto it = base_dirs_.begin(); it != base_dirs_.end(); ++it) { + if (it->first == mount_point) { + base_dirs_.erase(it); + return true; + } + } + return false; + } + + void Server::set_file_extension_and_mimetype_mapping(const char* ext, + const char* mime) { + file_extension_and_mimetype_map_[ext] = mime; + } + + void Server::set_file_request_handler(Handler handler) { + file_request_handler_ = std::move(handler); + } + + void Server::set_error_handler(Handler handler) { + error_handler_ = std::move(handler); + } + + void Server::set_logger(Logger logger) { logger_ = std::move(logger); } + + void + Server::set_expect_100_continue_handler(Expect100ContinueHandler handler) { + expect_100_continue_handler_ = std::move(handler); + } + + void Server::set_keep_alive_max_count(size_t count) { + keep_alive_max_count_ = count; + } + + void Server::set_read_timeout(time_t sec, time_t usec) { + read_timeout_sec_ = sec; + read_timeout_usec_ = usec; + } + + void Server::set_payload_max_length(size_t length) { + payload_max_length_ = length; + } + + bool Server::bind_to_port(const char* host, int port, int socket_flags) { + if (bind_internal(host, port, socket_flags) < 0) return false; + return true; + } + int Server::bind_to_any_port(const char* host, int socket_flags) { + return bind_internal(host, 0, socket_flags); + } + + bool Server::listen_after_bind() { return listen_internal(); } + + bool Server::listen(const char* host, int port, int socket_flags) { + return bind_to_port(host, port, socket_flags) && listen_internal(); + } + + bool Server::is_running() const { return is_running_; } + + void Server::stop() { + if (is_running_) { + assert(svr_sock_ != INVALID_SOCKET); + std::atomic sock(svr_sock_.exchange(INVALID_SOCKET)); + detail::shutdown_socket(sock); + detail::close_socket(sock); + } + } + + bool Server::parse_request_line(const char* s, Request& req) { + const static std::regex re( + "(GET|HEAD|POST|PUT|DELETE|CONNECT|OPTIONS|TRACE|PATCH|PRI) " + "(([^?]+)(?:\\?(.*?))?) (HTTP/1\\.[01])\r\n"); + + std::cmatch m; + if (std::regex_match(s, m, re)) { + req.version = std::string(m[5]); + req.method = std::string(m[1]); + req.target = std::string(m[2]); + req.path = detail::decode_url(m[3], false); + + // Parse query text + auto len = std::distance(m[4].first, m[4].second); + if (len > 0) { detail::parse_query_text(m[4], req.params); } + + return true; + } + + return false; + } + + bool Server::write_response(Stream& strm, bool last_connection, + const Request& req, Response& res) { + assert(res.status != -1); + + if (400 <= res.status && error_handler_) { error_handler_(req, res); } + + detail::BufferStream bstrm; + + // Response line + if (!bstrm.write_format("HTTP/1.1 %d %s\r\n", res.status, + detail::status_message(res.status))) { + return false; + } + + // Headers + if (last_connection || req.get_header_value("Connection") == "close") { + res.set_header("Connection", "close"); + } + + if (!last_connection && req.get_header_value("Connection") == "Keep-Alive") { + res.set_header("Connection", "Keep-Alive"); + } + + if (!res.has_header("Content-Type") && + (!res.body.empty() || res.content_length > 0)) { + res.set_header("Content-Type", "text/plain"); + } + + if (!res.has_header("Accept-Ranges") && req.method == "HEAD") { + res.set_header("Accept-Ranges", "bytes"); + } + + std::string content_type; + std::string boundary; + + if (req.ranges.size() > 1) { + boundary = detail::make_multipart_data_boundary(); + + auto it = res.headers.find("Content-Type"); + if (it != res.headers.end()) { + content_type = it->second; + res.headers.erase(it); + } + + res.headers.emplace("Content-Type", + "multipart/byteranges; boundary=" + boundary); + } + + if (res.body.empty()) { + if (res.content_length > 0) { + size_t length = 0; + if (req.ranges.empty()) { + length = res.content_length; + } + else if (req.ranges.size() == 1) { + auto offsets = + detail::get_range_offset_and_length(req, res.content_length, 0); + auto offset = offsets.first; + length = offsets.second; + auto content_range = detail::make_content_range_header_field( + offset, length, res.content_length); + res.set_header("Content-Range", content_range); + } + else { + length = detail::get_multipart_ranges_data_length(req, res, boundary, + content_type); + } + res.set_header("Content-Length", std::to_string(length)); + } + else { + if (res.content_provider) { + res.set_header("Transfer-Encoding", "chunked"); + } + else { + res.set_header("Content-Length", "0"); + } + } + } + else { + if (req.ranges.empty()) { + ; + } + else if (req.ranges.size() == 1) { + auto offsets = + detail::get_range_offset_and_length(req, res.body.size(), 0); + auto offset = offsets.first; + auto length = offsets.second; + auto content_range = detail::make_content_range_header_field( + offset, length, res.body.size()); + res.set_header("Content-Range", content_range); + res.body = res.body.substr(offset, length); + } + else { + res.body = + detail::make_multipart_ranges_data(req, res, boundary, content_type); + } + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + // TODO: 'Accept-Encoding' has gzip, not gzip;q=0 + const auto& encodings = req.get_header_value("Accept-Encoding"); + if (encodings.find("gzip") != std::string::npos && + detail::can_compress(res.get_header_value("Content-Type"))) { + if (detail::compress(res.body)) { + res.set_header("Content-Encoding", "gzip"); + } + } +#endif + + auto length = std::to_string(res.body.size()); + res.set_header("Content-Length", length); + } + + if (!detail::write_headers(bstrm, res, Headers())) { return false; } + + // Flush buffer + auto& data = bstrm.get_buffer(); + strm.write(data.data(), data.size()); + + // Body + if (req.method != "HEAD") { + if (!res.body.empty()) { + if (!strm.write(res.body)) { return false; } + } + else if (res.content_provider) { + if (!write_content_with_provider(strm, req, res, boundary, + content_type)) { + return false; + } + } + } + + // Log + if (logger_) { logger_(req, res); } + + return true; + } + + bool + Server::write_content_with_provider(Stream& strm, const Request& req, + Response& res, const std::string& boundary, + const std::string& content_type) { + if (res.content_length) { + if (req.ranges.empty()) { + if (detail::write_content(strm, res.content_provider, 0, + res.content_length) < 0) { + return false; + } + } + else if (req.ranges.size() == 1) { + auto offsets = + detail::get_range_offset_and_length(req, res.content_length, 0); + auto offset = offsets.first; + auto length = offsets.second; + if (detail::write_content(strm, res.content_provider, offset, length) < + 0) { + return false; + } + } + else { + if (!detail::write_multipart_ranges_data(strm, req, res, boundary, + content_type)) { + return false; + } + } + } + else { + auto is_shutting_down = [this]() { + return this->svr_sock_ == INVALID_SOCKET; + }; + if (detail::write_content_chunked(strm, res.content_provider, + is_shutting_down) < 0) { + return false; + } + } + return true; + } + + bool Server::read_content(Stream& strm, bool last_connection, + Request& req, Response& res) { + MultipartFormDataMap::iterator cur; + auto ret = read_content_core( + strm, last_connection, req, res, + // Regular + [&](const char* buf, size_t n) { + if (req.body.size() + n > req.body.max_size()) { return false; } + req.body.append(buf, n); + return true; + }, + // Multipart + [&](const MultipartFormData& file) { + cur = req.files.emplace(file.name, file); + return true; + }, + [&](const char* buf, size_t n) { + auto& content = cur->second.content; + if (content.size() + n > content.max_size()) { return false; } + content.append(buf, n); + return true; + }); + + const auto& content_type = req.get_header_value("Content-Type"); + if (!content_type.find("application/x-www-form-urlencoded")) { + detail::parse_query_text(req.body, req.params); + } + + return ret; + } + + bool Server::read_content_with_content_receiver( + Stream& strm, bool last_connection, Request& req, Response& res, + ContentReceiver receiver, MultipartContentHeader multipart_header, + ContentReceiver multipart_receiver) { + return read_content_core(strm, last_connection, req, res, receiver, + multipart_header, multipart_receiver); + } + + bool Server::read_content_core(Stream& strm, bool last_connection, + Request& req, Response& res, + ContentReceiver receiver, + MultipartContentHeader mulitpart_header, + ContentReceiver multipart_receiver) { + detail::MultipartFormDataParser multipart_form_data_parser; + ContentReceiver out; + + if (req.is_multipart_form_data()) { + const auto& content_type = req.get_header_value("Content-Type"); + std::string boundary; + if (!detail::parse_multipart_boundary(content_type, boundary)) { + res.status = 400; + return write_response(strm, last_connection, req, res); + } + + multipart_form_data_parser.set_boundary(boundary); + out = [&](const char* buf, size_t n) { + return multipart_form_data_parser.parse(buf, n, multipart_receiver, + mulitpart_header); + }; + } + else { + out = receiver; + } + + if (!detail::read_content(strm, req, payload_max_length_, res.status, + Progress(), out)) { + return write_response(strm, last_connection, req, res); + } + + if (req.is_multipart_form_data()) { + if (!multipart_form_data_parser.is_valid()) { + res.status = 400; + return write_response(strm, last_connection, req, res); + } + } + + return true; + } + + bool Server::handle_file_request(Request& req, Response& res, + bool head) { + for (const auto& kv : base_dirs_) { + const auto& mount_point = kv.first; + const auto& base_dir = kv.second; + + // Prefix match + if (!req.path.find(mount_point)) { + std::string sub_path = "/" + req.path.substr(mount_point.size()); + if (detail::is_valid_path(sub_path)) { + auto path = base_dir + sub_path; + if (path.back() == '/') { path += "index.html"; } + + if (detail::is_file(path)) { + detail::read_file(path, res.body); + auto type = + detail::find_content_type(path, file_extension_and_mimetype_map_); + if (type) { res.set_header("Content-Type", type); } + res.status = 200; + if (!head && file_request_handler_) { + file_request_handler_(req, res); + } + return true; + } + } + } + } + return false; + } + + socket_t Server::create_server_socket(const char* host, int port, + int socket_flags) const { + return detail::create_socket( + host, port, + [](socket_t sock, struct addrinfo& ai) -> bool { + if (::bind(sock, ai.ai_addr, static_cast(ai.ai_addrlen))) { + return false; + } + if (::listen(sock, 5)) { // Listen through 5 channels + return false; + } + return true; + }, + socket_flags); + } + + int Server::bind_internal(const char* host, int port, int socket_flags) { + if (!is_valid()) { return -1; } + + svr_sock_ = create_server_socket(host, port, socket_flags); + if (svr_sock_ == INVALID_SOCKET) { return -1; } + + if (port == 0) { + struct sockaddr_storage address; + socklen_t len = sizeof(address); + if (getsockname(svr_sock_, reinterpret_cast(&address), + &len) == -1) { + return -1; + } + if (address.ss_family == AF_INET) { + return ntohs(reinterpret_cast(&address)->sin_port); + } + else if (address.ss_family == AF_INET6) { + return ntohs( + reinterpret_cast(&address)->sin6_port); + } + else { + return -1; + } + } + else { + return port; + } + } + + bool Server::listen_internal() { + auto ret = true; + is_running_ = true; + + { + std::unique_ptr task_queue(new_task_queue()); + + for (;;) { + if (svr_sock_ == INVALID_SOCKET) { + // The server socket was closed by 'stop' method. + break; + } + + auto val = detail::select_read(svr_sock_, 0, 100000); + + if (val == 0) { // Timeout + continue; + } + + socket_t sock = accept(svr_sock_, nullptr, nullptr); + + if (sock == INVALID_SOCKET) { + if (errno == EMFILE) { + // The per-process limit of open file descriptors has been reached. + // Try to accept new connections after a short sleep. + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + continue; + } + if (svr_sock_ != INVALID_SOCKET) { + detail::close_socket(svr_sock_); + ret = false; + } + else { + ; // The server socket was closed by user. + } + break; + } + +#if __cplusplus > 201703L + task_queue->enqueue([=, this]() { process_and_close_socket(sock); }); +#else + task_queue->enqueue([=]() { process_and_close_socket(sock); }); +#endif + } + + task_queue->shutdown(); + } + + is_running_ = false; + return ret; + } + + bool Server::routing(Request& req, Response& res, Stream& strm, + bool last_connection) { + // File handler + bool is_head_request = req.method == "HEAD"; + if ((req.method == "GET" || is_head_request) && + handle_file_request(req, res, is_head_request)) { + return true; + } + + if (detail::expect_content(req)) { + // Content reader handler + { + ContentReader reader( + [&](ContentReceiver receiver) { + return read_content_with_content_receiver( + strm, last_connection, req, res, receiver, nullptr, nullptr); + }, + [&](MultipartContentHeader header, ContentReceiver receiver) { + return read_content_with_content_receiver( + strm, last_connection, req, res, nullptr, header, receiver); + }); + + if (req.method == "POST") { + if (dispatch_request_for_content_reader( + req, res, reader, post_handlers_for_content_reader_)) { + return true; + } + } + else if (req.method == "PUT") { + if (dispatch_request_for_content_reader( + req, res, reader, put_handlers_for_content_reader_)) { + return true; + } + } + else if (req.method == "PATCH") { + if (dispatch_request_for_content_reader( + req, res, reader, patch_handlers_for_content_reader_)) { + return true; + } + } + } + + // Read content into `req.body` + if (!read_content(strm, last_connection, req, res)) { return false; } + } + + // Regular handler + if (req.method == "GET" || req.method == "HEAD") { + return dispatch_request(req, res, get_handlers_); + } + else if (req.method == "POST") { + return dispatch_request(req, res, post_handlers_); + } + else if (req.method == "PUT") { + return dispatch_request(req, res, put_handlers_); + } + else if (req.method == "DELETE") { + return dispatch_request(req, res, delete_handlers_); + } + else if (req.method == "OPTIONS") { + return dispatch_request(req, res, options_handlers_); + } + else if (req.method == "PATCH") { + return dispatch_request(req, res, patch_handlers_); + } + + res.status = 400; + return false; + } + + bool Server::dispatch_request(Request& req, Response& res, + Handlers& handlers) { + + try { + for (const auto& x : handlers) { + const auto& pattern = x.first; + const auto& handler = x.second; + + if (std::regex_match(req.path, req.matches, pattern)) { + handler(req, res); + return true; + } + } + } + catch (const std::exception & ex) { + res.status = 500; + res.set_header("EXCEPTION_WHAT", ex.what()); + } + catch (...) { + res.status = 500; + res.set_header("EXCEPTION_WHAT", "UNKNOWN"); + } + return false; + } + + bool Server::dispatch_request_for_content_reader( + Request& req, Response& res, ContentReader content_reader, + HandlersForContentReader& handlers) { + for (const auto& x : handlers) { + const auto& pattern = x.first; + const auto& handler = x.second; + + if (std::regex_match(req.path, req.matches, pattern)) { + handler(req, res, content_reader); + return true; + } + } + return false; + } + + bool + Server::process_request(Stream& strm, bool last_connection, + bool& connection_close, + const std::function& setup_request) { + std::array buf{}; + + detail::stream_line_reader line_reader(strm, buf.data(), buf.size()); + + // Connection has been closed on client + if (!line_reader.getline()) { return false; } + + Request req; + Response res; + + res.version = "HTTP/1.1"; + + // Check if the request URI doesn't exceed the limit + if (line_reader.size() > CPPHTTPLIB_REQUEST_URI_MAX_LENGTH) { + Headers dummy; + detail::read_headers(strm, dummy); + res.status = 414; + return write_response(strm, last_connection, req, res); + } + + // Request line and headers + if (!parse_request_line(line_reader.ptr(), req) || + !detail::read_headers(strm, req.headers)) { + res.status = 400; + return write_response(strm, last_connection, req, res); + } + + if (req.get_header_value("Connection") == "close") { + connection_close = true; + } + + if (req.version == "HTTP/1.0" && + req.get_header_value("Connection") != "Keep-Alive") { + connection_close = true; + } + + req.set_header("REMOTE_ADDR", strm.get_remote_addr()); + + if (req.has_header("Range")) { + const auto& range_header_value = req.get_header_value("Range"); + if (!detail::parse_range_header(range_header_value, req.ranges)) { + // TODO: error + } + } + + if (setup_request) { setup_request(req); } + + if (req.get_header_value("Expect") == "100-continue") { + auto status = 100; + if (expect_100_continue_handler_) { + status = expect_100_continue_handler_(req, res); + } + switch (status) { + case 100: + case 417: + strm.write_format("HTTP/1.1 %d %s\r\n\r\n", status, + detail::status_message(status)); + break; + default: return write_response(strm, last_connection, req, res); + } + } + + // Rounting + if (routing(req, res, strm, last_connection)) { + if (res.status == -1) { res.status = req.ranges.empty() ? 200 : 206; } + } + else { + if (res.status == -1) { res.status = 404; } + } + + return write_response(strm, last_connection, req, res); + } + + bool Server::is_valid() const { return true; } + + bool Server::process_and_close_socket(socket_t sock) { + return detail::process_and_close_socket( + false, sock, keep_alive_max_count_, read_timeout_sec_, read_timeout_usec_, + [this](Stream& strm, bool last_connection, bool& connection_close) { + return process_request(strm, last_connection, connection_close, + nullptr); + }); + } + + // HTTP client implementation + Client::Client(const std::string& host, int port, + const std::string& client_cert_path, + const std::string& client_key_path) + : host_(host), port_(port), + host_and_port_(host_ + ":" + std::to_string(port_)), + client_cert_path_(client_cert_path), client_key_path_(client_key_path) {} + + Client::~Client() {} + + bool Client::is_valid() const { return true; } + + socket_t Client::create_client_socket() const { + if (!proxy_host_.empty()) { + return detail::create_client_socket(proxy_host_.c_str(), proxy_port_, + timeout_sec_, interface_); + } + return detail::create_client_socket(host_.c_str(), port_, timeout_sec_, + interface_); + } + + bool Client::read_response_line(Stream& strm, Response& res) { + std::array buf; + + detail::stream_line_reader line_reader(strm, buf.data(), buf.size()); + + if (!line_reader.getline()) { return false; } + + const static std::regex re("(HTTP/1\\.[01]) (\\d+?) .*\r\n"); + + std::cmatch m; + if (std::regex_match(line_reader.ptr(), m, re)) { + res.version = std::string(m[1]); + res.status = std::stoi(std::string(m[2])); + } + + return true; + } + + bool Client::send(const Request& req, Response& res) { + auto sock = create_client_socket(); + if (sock == INVALID_SOCKET) { return false; } + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if (is_ssl() && !proxy_host_.empty()) { + bool error; + if (!connect(sock, res, error)) { return error; } + } +#endif + + return process_and_close_socket( + sock, 1, [&](Stream& strm, bool last_connection, bool& connection_close) { + return handle_request(strm, req, res, last_connection, + connection_close); + }); + } + + bool Client::send(const std::vector& requests, + std::vector& responses) { + size_t i = 0; + while (i < requests.size()) { + auto sock = create_client_socket(); + if (sock == INVALID_SOCKET) { return false; } + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if (is_ssl() && !proxy_host_.empty()) { + Response res; + bool error; + if (!connect(sock, res, error)) { return false; } + } +#endif + + if (!process_and_close_socket(sock, requests.size() - i, + [&](Stream& strm, bool last_connection, + bool& connection_close) -> bool { + auto& req = requests[i++]; + auto res = Response(); + auto ret = handle_request(strm, req, res, + last_connection, + connection_close); + if (ret) { + responses.emplace_back(std::move(res)); + } + return ret; + })) { + return false; + } + } + + return true; + } + + bool Client::handle_request(Stream& strm, const Request& req, + Response& res, bool last_connection, + bool& connection_close) { + if (req.path.empty()) { return false; } + + bool ret; + + if (!is_ssl() && !proxy_host_.empty()) { + auto req2 = req; + req2.path = "http://" + host_and_port_ + req.path; + ret = process_request(strm, req2, res, last_connection, connection_close); + } + else { + ret = process_request(strm, req, res, last_connection, connection_close); + } + + if (!ret) { return false; } + + if (300 < res.status && res.status < 400 && follow_location_) { + ret = redirect(req, res); + } + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if (res.status == 401 || res.status == 407) { + auto is_proxy = res.status == 407; + const auto& username = + is_proxy ? proxy_digest_auth_username_ : digest_auth_username_; + const auto& password = + is_proxy ? proxy_digest_auth_password_ : digest_auth_password_; + + if (!username.empty() && !password.empty()) { + std::map auth; + if (parse_www_authenticate(res, auth, is_proxy)) { + Request new_req = req; + auto key = is_proxy ? "Proxy-Authorization" : "WWW-Authorization"; + new_req.headers.erase(key); + new_req.headers.insert(make_digest_authentication_header( + req, auth, 1, random_string(10), username, password, is_proxy)); + + Response new_res; + + ret = send(new_req, new_res); + if (ret) { res = new_res; } + } + } + } +#endif + + return ret; + } + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + bool Client::connect(socket_t sock, Response& res, bool& error) { + error = true; + Response res2; + + if (!detail::process_socket( + true, sock, 1, read_timeout_sec_, read_timeout_usec_, + [&](Stream& strm, bool /*last_connection*/, bool& connection_close) { + Request req2; + req2.method = "CONNECT"; + req2.path = host_and_port_; + return process_request(strm, req2, res2, false, connection_close); + })) { + detail::close_socket(sock); + error = false; + return false; + } + + if (res2.status == 407) { + if (!proxy_digest_auth_username_.empty() && + !proxy_digest_auth_password_.empty()) { + std::map auth; + if (parse_www_authenticate(res2, auth, true)) { + Response res3; + if (!detail::process_socket( + true, sock, 1, read_timeout_sec_, read_timeout_usec_, + [&](Stream& strm, bool /*last_connection*/, + bool& connection_close) { + Request req3; + req3.method = "CONNECT"; + req3.path = host_and_port_; + req3.headers.insert(make_digest_authentication_header( + req3, auth, 1, random_string(10), + proxy_digest_auth_username_, proxy_digest_auth_password_, + true)); + return process_request(strm, req3, res3, false, + connection_close); + })) { + detail::close_socket(sock); + error = false; + return false; + } + } + } + else { + res = res2; + return false; + } + } + + return true; + } +#endif + + bool Client::redirect(const Request& req, Response& res) { + if (req.redirect_count == 0) { return false; } + + auto location = res.get_header_value("location"); + if (location.empty()) { return false; } + + const static std::regex re( + R"(^(?:([^:/?#]+):)?(?://([^/?#]*))?([^?#]*(?:\?[^#]*)?)(?:#.*)?)"); + + std::smatch m; + if (!regex_match(location, m, re)) { return false; } + + auto scheme = is_ssl() ? "https" : "http"; + + auto next_scheme = m[1].str(); + auto next_host = m[2].str(); + auto next_path = m[3].str(); + if (next_scheme.empty()) { next_scheme = scheme; } + if (next_scheme.empty()) { next_scheme = scheme; } + if (next_host.empty()) { next_host = host_; } + if (next_path.empty()) { next_path = "/"; } + + if (next_scheme == scheme && next_host == host_) { + return detail::redirect(*this, req, res, next_path); + } + else { + if (next_scheme == "https") { +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + SSLClient cli(next_host.c_str()); + cli.copy_settings(*this); + return detail::redirect(cli, req, res, next_path); +#else + return false; +#endif + } + else { + Client cli(next_host.c_str()); + cli.copy_settings(*this); + return detail::redirect(cli, req, res, next_path); + } + } + } + + bool Client::write_request(Stream& strm, const Request& req, + bool last_connection) { + detail::BufferStream bstrm; + + // Request line + const auto& path = detail::encode_url(req.path); + + bstrm.write_format("%s %s HTTP/1.1\r\n", req.method.c_str(), path.c_str()); + + // Additonal headers + Headers headers; + if (last_connection) { headers.emplace("Connection", "close"); } + + if (!req.has_header("Host")) { + if (is_ssl()) { + if (port_ == 443) { + headers.emplace("Host", host_); + } + else { + headers.emplace("Host", host_and_port_); + } + } + else { + if (port_ == 80) { + headers.emplace("Host", host_); + } + else { + headers.emplace("Host", host_and_port_); + } + } + } + + if (!req.has_header("Accept")) { headers.emplace("Accept", "*/*"); } + + if (!req.has_header("User-Agent")) { + headers.emplace("User-Agent", "cpp-httplib/0.5"); + } + + if (req.body.empty()) { + if (req.content_provider) { + auto length = std::to_string(req.content_length); + headers.emplace("Content-Length", length); + } + else { + headers.emplace("Content-Length", "0"); + } + } + else { + if (!req.has_header("Content-Type")) { + headers.emplace("Content-Type", "text/plain"); + } + + if (!req.has_header("Content-Length")) { + auto length = std::to_string(req.body.size()); + headers.emplace("Content-Length", length); + } + } + + if (!basic_auth_username_.empty() && !basic_auth_password_.empty()) { + headers.insert(make_basic_authentication_header( + basic_auth_username_, basic_auth_password_, false)); + } + + if (!proxy_basic_auth_username_.empty() && + !proxy_basic_auth_password_.empty()) { + headers.insert(make_basic_authentication_header( + proxy_basic_auth_username_, proxy_basic_auth_password_, true)); + } + + detail::write_headers(bstrm, req, headers); + + // Flush buffer + auto& data = bstrm.get_buffer(); + strm.write(data.data(), data.size()); + + // Body + if (req.body.empty()) { + if (req.content_provider) { + size_t offset = 0; + size_t end_offset = req.content_length; + + DataSink data_sink; + data_sink.write = [&](const char* d, size_t l) { + auto written_length = strm.write(d, l); + offset += static_cast(written_length); + }; + data_sink.is_writable = [&](void) { return strm.is_writable(); }; + + while (offset < end_offset) { + req.content_provider(offset, end_offset - offset, data_sink); + } + } + } + else { + strm.write(req.body); + } + + return true; + } + + std::shared_ptr Client::send_with_content_provider( + const char* method, const char* path, const Headers& headers, + const std::string& body, size_t content_length, + ContentProvider content_provider, const char* content_type) { + Request req; + req.method = method; + req.headers = headers; + req.path = path; + + req.headers.emplace("Content-Type", content_type); + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + if (compress_) { + if (content_provider) { + size_t offset = 0; + + DataSink data_sink; + data_sink.write = [&](const char* data, size_t data_len) { + req.body.append(data, data_len); + offset += data_len; + }; + data_sink.is_writable = [&](void) { return true; }; + + while (offset < content_length) { + content_provider(offset, content_length - offset, data_sink); + } + } + else { + req.body = body; + } + + if (!detail::compress(req.body)) { return nullptr; } + req.headers.emplace("Content-Encoding", "gzip"); + } + else +#endif + { + if (content_provider) { + req.content_length = content_length; + req.content_provider = content_provider; + } + else { + req.body = body; + } + } + + auto res = std::make_shared(); + + return send(req, *res) ? res : nullptr; + } + + bool Client::process_request(Stream& strm, const Request& req, + Response& res, bool last_connection, + bool& connection_close) { + // Send request + if (!write_request(strm, req, last_connection)) { return false; } + + // Receive response and headers + if (!read_response_line(strm, res) || + !detail::read_headers(strm, res.headers)) { + return false; + } + + if (res.get_header_value("Connection") == "close" || + res.version == "HTTP/1.0") { + connection_close = true; + } + + if (req.response_handler) { + if (!req.response_handler(res)) { return false; } + } + + // Body + if (req.method != "HEAD" && req.method != "CONNECT") { + ContentReceiver out = [&](const char* buf, size_t n) { + if (res.body.size() + n > res.body.max_size()) { return false; } + res.body.append(buf, n); + return true; + }; + + if (req.content_receiver) { + out = [&](const char* buf, size_t n) { + return req.content_receiver(buf, n); + }; + } + + int dummy_status; + if (!detail::read_content(strm, res, std::numeric_limits::max(), + dummy_status, req.progress, out)) { + return false; + } + } + + // Log + if (logger_) { logger_(req, res); } + + return true; + } + + bool Client::process_and_close_socket( + socket_t sock, size_t request_count, + std::function + callback) { + request_count = std::min(request_count, keep_alive_max_count_); + return detail::process_and_close_socket(true, sock, request_count, + read_timeout_sec_, read_timeout_usec_, + callback); + } + + bool Client::is_ssl() const { return false; } + + std::shared_ptr Client::Get(const char* path) { + return Get(path, Headers(), Progress()); + } + + std::shared_ptr Client::Get(const char* path, + Progress progress) { + return Get(path, Headers(), std::move(progress)); + } + + std::shared_ptr Client::Get(const char* path, + const Headers& headers) { + return Get(path, headers, Progress()); + } + + std::shared_ptr + Client::Get(const char* path, const Headers& headers, Progress progress) { + Request req; + req.method = "GET"; + req.path = path; + req.headers = headers; + req.progress = std::move(progress); + + auto res = std::make_shared(); + return send(req, *res) ? res : nullptr; + } + + std::shared_ptr Client::Get(const char* path, + ContentReceiver content_receiver) { + return Get(path, Headers(), nullptr, std::move(content_receiver), Progress()); + } + + std::shared_ptr Client::Get(const char* path, + ContentReceiver content_receiver, + Progress progress) { + return Get(path, Headers(), nullptr, std::move(content_receiver), + std::move(progress)); + } + + std::shared_ptr Client::Get(const char* path, + const Headers& headers, + ContentReceiver content_receiver) { + return Get(path, headers, nullptr, std::move(content_receiver), Progress()); + } + + std::shared_ptr Client::Get(const char* path, + const Headers& headers, + ContentReceiver content_receiver, + Progress progress) { + return Get(path, headers, nullptr, std::move(content_receiver), + std::move(progress)); + } + + std::shared_ptr Client::Get(const char* path, + const Headers& headers, + ResponseHandler response_handler, + ContentReceiver content_receiver) { + return Get(path, headers, std::move(response_handler), content_receiver, + Progress()); + } + + std::shared_ptr Client::Get(const char* path, + const Headers& headers, + ResponseHandler response_handler, + ContentReceiver content_receiver, + Progress progress) { + Request req; + req.method = "GET"; + req.path = path; + req.headers = headers; + req.response_handler = std::move(response_handler); + req.content_receiver = std::move(content_receiver); + req.progress = std::move(progress); + + auto res = std::make_shared(); + return send(req, *res) ? res : nullptr; + } + + std::shared_ptr Client::Head(const char* path) { + return Head(path, Headers()); + } + + std::shared_ptr Client::Head(const char* path, + const Headers& headers) { + Request req; + req.method = "HEAD"; + req.headers = headers; + req.path = path; + + auto res = std::make_shared(); + + return send(req, *res) ? res : nullptr; + } + + std::shared_ptr Client::Post(const char* path, + const std::string& body, + const char* content_type) { + return Post(path, Headers(), body, content_type); + } + + std::shared_ptr Client::Post(const char* path, + const Headers& headers, + const std::string& body, + const char* content_type) { + return send_with_content_provider("POST", path, headers, body, 0, nullptr, + content_type); + } + + std::shared_ptr Client::Post(const char* path, + const Params& params) { + return Post(path, Headers(), params); + } + + std::shared_ptr Client::Post(const char* path, + size_t content_length, + ContentProvider content_provider, + const char* content_type) { + return Post(path, Headers(), content_length, content_provider, content_type); + } + + std::shared_ptr + Client::Post(const char* path, const Headers& headers, size_t content_length, + ContentProvider content_provider, const char* content_type) { + return send_with_content_provider("POST", path, headers, std::string(), + content_length, content_provider, + content_type); + } + + std::shared_ptr + Client::Post(const char* path, const Headers& headers, const Params& params) { + auto query = detail::params_to_query_str(params); + return Post(path, headers, query, "application/x-www-form-urlencoded"); + } + + std::shared_ptr + Client::Post(const char* path, const MultipartFormDataItems& items) { + return Post(path, Headers(), items); + } + + std::shared_ptr + Client::Post(const char* path, const Headers& headers, + const MultipartFormDataItems& items) { + auto boundary = detail::make_multipart_data_boundary(); + + std::string body; + + for (const auto& item : items) { + body += "--" + boundary + "\r\n"; + body += "Content-Disposition: form-data; name=\"" + item.name + "\""; + if (!item.filename.empty()) { + body += "; filename=\"" + item.filename + "\""; + } + body += "\r\n"; + if (!item.content_type.empty()) { + body += "Content-Type: " + item.content_type + "\r\n"; + } + body += "\r\n"; + body += item.content + "\r\n"; + } + + body += "--" + boundary + "--\r\n"; + + std::string content_type = "multipart/form-data; boundary=" + boundary; + return Post(path, headers, body, content_type.c_str()); + } + + std::shared_ptr Client::Put(const char* path, + const std::string& body, + const char* content_type) { + return Put(path, Headers(), body, content_type); + } + + std::shared_ptr Client::Put(const char* path, + const Headers& headers, + const std::string& body, + const char* content_type) { + return send_with_content_provider("PUT", path, headers, body, 0, nullptr, + content_type); + } + + std::shared_ptr Client::Put(const char* path, + size_t content_length, + ContentProvider content_provider, + const char* content_type) { + return Put(path, Headers(), content_length, content_provider, content_type); + } + + std::shared_ptr + Client::Put(const char* path, const Headers& headers, size_t content_length, + ContentProvider content_provider, const char* content_type) { + return send_with_content_provider("PUT", path, headers, std::string(), + content_length, content_provider, + content_type); + } + + std::shared_ptr Client::Put(const char* path, + const Params& params) { + return Put(path, Headers(), params); + } + + std::shared_ptr + Client::Put(const char* path, const Headers& headers, const Params& params) { + auto query = detail::params_to_query_str(params); + return Put(path, headers, query, "application/x-www-form-urlencoded"); + } + + std::shared_ptr Client::Patch(const char* path, + const std::string& body, + const char* content_type) { + return Patch(path, Headers(), body, content_type); + } + + std::shared_ptr Client::Patch(const char* path, + const Headers& headers, + const std::string& body, + const char* content_type) { + return send_with_content_provider("PATCH", path, headers, body, 0, nullptr, + content_type); + } + + std::shared_ptr Client::Patch(const char* path, + size_t content_length, + ContentProvider content_provider, + const char* content_type) { + return Patch(path, Headers(), content_length, content_provider, content_type); + } + + std::shared_ptr + Client::Patch(const char* path, const Headers& headers, size_t content_length, + ContentProvider content_provider, const char* content_type) { + return send_with_content_provider("PATCH", path, headers, std::string(), + content_length, content_provider, + content_type); + } + + std::shared_ptr Client::Delete(const char* path) { + return Delete(path, Headers(), std::string(), nullptr); + } + + std::shared_ptr Client::Delete(const char* path, + const std::string& body, + const char* content_type) { + return Delete(path, Headers(), body, content_type); + } + + std::shared_ptr Client::Delete(const char* path, + const Headers& headers) { + return Delete(path, headers, std::string(), nullptr); + } + + std::shared_ptr Client::Delete(const char* path, + const Headers& headers, + const std::string& body, + const char* content_type) { + Request req; + req.method = "DELETE"; + req.headers = headers; + req.path = path; + + if (content_type) { req.headers.emplace("Content-Type", content_type); } + req.body = body; + + auto res = std::make_shared(); + + return send(req, *res) ? res : nullptr; + } + + std::shared_ptr Client::Options(const char* path) { + return Options(path, Headers()); + } + + std::shared_ptr Client::Options(const char* path, + const Headers& headers) { + Request req; + req.method = "OPTIONS"; + req.path = path; + req.headers = headers; + + auto res = std::make_shared(); + + return send(req, *res) ? res : nullptr; + } + + void Client::set_timeout_sec(time_t timeout_sec) { + timeout_sec_ = timeout_sec; + } + + void Client::set_read_timeout(time_t sec, time_t usec) { + read_timeout_sec_ = sec; + read_timeout_usec_ = usec; + } + + void Client::set_keep_alive_max_count(size_t count) { + keep_alive_max_count_ = count; + } + + void Client::set_basic_auth(const char* username, const char* password) { + basic_auth_username_ = username; + basic_auth_password_ = password; + } + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void Client::set_digest_auth(const char* username, + const char* password) { + digest_auth_username_ = username; + digest_auth_password_ = password; + } +#endif + + void Client::set_follow_location(bool on) { follow_location_ = on; } + + void Client::set_compress(bool on) { compress_ = on; } + + void Client::set_interface(const char* intf) { interface_ = intf; } + + void Client::set_proxy(const char* host, int port) { + proxy_host_ = host; + proxy_port_ = port; + } + + void Client::set_proxy_basic_auth(const char* username, + const char* password) { + proxy_basic_auth_username_ = username; + proxy_basic_auth_password_ = password; + } + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void Client::set_proxy_digest_auth(const char* username, + const char* password) { + proxy_digest_auth_username_ = username; + proxy_digest_auth_password_ = password; + } +#endif + + void Client::set_logger(Logger logger) { logger_ = std::move(logger); } + + /* + * SSL Implementation + */ +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + namespace detail { + + template + bool process_and_close_socket_ssl( + bool is_client_request, socket_t sock, size_t keep_alive_max_count, + time_t read_timeout_sec, time_t read_timeout_usec, SSL_CTX* ctx, + std::mutex& ctx_mutex, U SSL_connect_or_accept, V setup, T callback) { + assert(keep_alive_max_count > 0); + + SSL* ssl = nullptr; + { + std::lock_guard guard(ctx_mutex); + ssl = SSL_new(ctx); + } + + if (!ssl) { + close_socket(sock); + return false; + } + + auto bio = BIO_new_socket(static_cast(sock), BIO_NOCLOSE); + SSL_set_bio(ssl, bio, bio); + + if (!setup(ssl)) { + SSL_shutdown(ssl); + { + std::lock_guard guard(ctx_mutex); + SSL_free(ssl); + } + + close_socket(sock); + return false; + } + + auto ret = false; + + if (SSL_connect_or_accept(ssl) == 1) { + if (keep_alive_max_count > 1) { + auto count = keep_alive_max_count; + while (count > 0 && + (is_client_request || + detail::select_read(sock, CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND, + CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND) > 0)) { + SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec); + auto last_connection = count == 1; + auto connection_close = false; + + ret = callback(ssl, strm, last_connection, connection_close); + if (!ret || connection_close) { break; } + + count--; + } + } + else { + SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec); + auto dummy_connection_close = false; + ret = callback(ssl, strm, true, dummy_connection_close); + } + } + + SSL_shutdown(ssl); + { + std::lock_guard guard(ctx_mutex); + SSL_free(ssl); + } + + close_socket(sock); + + return ret; + } + +#if OPENSSL_VERSION_NUMBER < 0x10100000L + static std::shared_ptr> openSSL_locks_; + + class SSLThreadLocks { + public: + SSLThreadLocks() { + openSSL_locks_ = + std::make_shared>(CRYPTO_num_locks()); + CRYPTO_set_locking_callback(locking_callback); + } + + ~SSLThreadLocks() { CRYPTO_set_locking_callback(nullptr); } + + private: + static void locking_callback(int mode, int type, const char* /*file*/, + int /*line*/) { + auto& locks = *openSSL_locks_; + if (mode & CRYPTO_LOCK) { + locks[type].lock(); + } + else { + locks[type].unlock(); + } + } + }; + +#endif + + class SSLInit { + public: + SSLInit() { +#if OPENSSL_VERSION_NUMBER < 0x1010001fL + SSL_load_error_strings(); + SSL_library_init(); +#else + OPENSSL_init_ssl( + OPENSSL_INIT_LOAD_SSL_STRINGS | OPENSSL_INIT_LOAD_CRYPTO_STRINGS, NULL); +#endif + } + + ~SSLInit() { +#if OPENSSL_VERSION_NUMBER < 0x1010001fL + ERR_free_strings(); +#endif + } + + private: +#if OPENSSL_VERSION_NUMBER < 0x10100000L + SSLThreadLocks thread_init_; +#endif + }; + + // SSL socket stream implementation + SSLSocketStream::SSLSocketStream(socket_t sock, SSL* ssl, + time_t read_timeout_sec, + time_t read_timeout_usec) + : sock_(sock), ssl_(ssl), read_timeout_sec_(read_timeout_sec), + read_timeout_usec_(read_timeout_usec) {} + + SSLSocketStream::~SSLSocketStream() {} + + bool SSLSocketStream::is_readable() const { + return detail::select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0; + } + + bool SSLSocketStream::is_writable() const { + return detail::select_write(sock_, 0, 0) > 0; + } + + ssize_t SSLSocketStream::read(char* ptr, size_t size) { + if (SSL_pending(ssl_) > 0 || + select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0) { + return SSL_read(ssl_, ptr, static_cast(size)); + } + return -1; + } + + ssize_t SSLSocketStream::write(const char* ptr, size_t size) { + if (is_writable()) { return SSL_write(ssl_, ptr, static_cast(size)); } + return -1; + } + + std::string SSLSocketStream::get_remote_addr() const { + return detail::get_remote_addr(sock_); + } + + static SSLInit sslinit_; + + } // namespace detail + + // SSL HTTP server implementation + SSLServer::SSLServer(const char* cert_path, const char* private_key_path, + const char* client_ca_cert_file_path, + const char* client_ca_cert_dir_path) { + ctx_ = SSL_CTX_new(SSLv23_server_method()); + + if (ctx_) { + SSL_CTX_set_options(ctx_, + SSL_OP_ALL | SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | + SSL_OP_NO_COMPRESSION | + SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION); + + // auto ecdh = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1); + // SSL_CTX_set_tmp_ecdh(ctx_, ecdh); + // EC_KEY_free(ecdh); + + if (SSL_CTX_use_certificate_chain_file(ctx_, cert_path) != 1 || + SSL_CTX_use_PrivateKey_file(ctx_, private_key_path, SSL_FILETYPE_PEM) != + 1) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } + else if (client_ca_cert_file_path || client_ca_cert_dir_path) { + // if (client_ca_cert_file_path) { + // auto list = SSL_load_client_CA_file(client_ca_cert_file_path); + // SSL_CTX_set_client_CA_list(ctx_, list); + // } + + SSL_CTX_load_verify_locations(ctx_, client_ca_cert_file_path, + client_ca_cert_dir_path); + + SSL_CTX_set_verify( + ctx_, + SSL_VERIFY_PEER | + SSL_VERIFY_FAIL_IF_NO_PEER_CERT, // SSL_VERIFY_CLIENT_ONCE, + nullptr); + } + } + } + + SSLServer::~SSLServer() { + if (ctx_) { SSL_CTX_free(ctx_); } + } + + bool SSLServer::is_valid() const { return ctx_; } + + bool SSLServer::process_and_close_socket(socket_t sock) { + return detail::process_and_close_socket_ssl( + false, sock, keep_alive_max_count_, read_timeout_sec_, read_timeout_usec_, + ctx_, ctx_mutex_, SSL_accept, [](SSL* /*ssl*/) { return true; }, + [this](SSL* ssl, Stream& strm, bool last_connection, + bool& connection_close) { + return process_request(strm, last_connection, connection_close, + [&](Request& req) { req.ssl = ssl; }); + }); + } + + // SSL HTTP client implementation + SSLClient::SSLClient(const std::string& host, int port, + const std::string& client_cert_path, + const std::string& client_key_path) + : Client(host, port, client_cert_path, client_key_path) { + ctx_ = SSL_CTX_new(SSLv23_client_method()); + + detail::split(&host_[0], &host_[host_.size()], '.', + [&](const char* b, const char* e) { + host_components_.emplace_back(std::string(b, e)); + }); + if (!client_cert_path.empty() && !client_key_path.empty()) { + if (SSL_CTX_use_certificate_file(ctx_, client_cert_path.c_str(), + SSL_FILETYPE_PEM) != 1 || + SSL_CTX_use_PrivateKey_file(ctx_, client_key_path.c_str(), + SSL_FILETYPE_PEM) != 1) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } + } + } + + SSLClient::~SSLClient() { + if (ctx_) { SSL_CTX_free(ctx_); } + } + + bool SSLClient::is_valid() const { return ctx_; } + + void SSLClient::set_ca_cert_path(const char* ca_cert_file_path, + const char* ca_cert_dir_path) { + if (ca_cert_file_path) { ca_cert_file_path_ = ca_cert_file_path; } + if (ca_cert_dir_path) { ca_cert_dir_path_ = ca_cert_dir_path; } + } + + void SSLClient::enable_server_certificate_verification(bool enabled) { + server_certificate_verification_ = enabled; + } + + long SSLClient::get_openssl_verify_result() const { + return verify_result_; + } + + SSL_CTX* SSLClient::ssl_context() const noexcept { return ctx_; } + + bool SSLClient::process_and_close_socket( + socket_t sock, size_t request_count, + std::function + callback) { + + request_count = std::min(request_count, keep_alive_max_count_); + + return is_valid() && + detail::process_and_close_socket_ssl( + true, sock, request_count, read_timeout_sec_, read_timeout_usec_, + ctx_, ctx_mutex_, + [&](SSL* ssl) { + if (ca_cert_file_path_.empty()) { + SSL_CTX_set_verify(ctx_, SSL_VERIFY_NONE, nullptr); + } + else { + if (!SSL_CTX_load_verify_locations( + ctx_, ca_cert_file_path_.c_str(), nullptr)) { + return false; + } + SSL_CTX_set_verify(ctx_, SSL_VERIFY_PEER, nullptr); + } + + if (SSL_connect(ssl) != 1) { return false; } + + if (server_certificate_verification_) { + verify_result_ = SSL_get_verify_result(ssl); + + if (verify_result_ != X509_V_OK) { return false; } + + auto server_cert = SSL_get_peer_certificate(ssl); + + if (server_cert == nullptr) { return false; } + + if (!verify_host(server_cert)) { + X509_free(server_cert); + return false; + } + X509_free(server_cert); + } + + return true; + }, + [&](SSL* ssl) { + SSL_set_tlsext_host_name(ssl, host_.c_str()); + return true; + }, + [&](SSL* /*ssl*/, Stream& strm, bool last_connection, + bool& connection_close) { + return callback(strm, last_connection, connection_close); + }); + } + + bool SSLClient::is_ssl() const { return true; } + + bool SSLClient::verify_host(X509* server_cert) const { + /* Quote from RFC2818 section 3.1 "Server Identity" + + If a subjectAltName extension of type dNSName is present, that MUST + be used as the identity. Otherwise, the (most specific) Common Name + field in the Subject field of the certificate MUST be used. Although + the use of the Common Name is existing practice, it is deprecated and + Certification Authorities are encouraged to use the dNSName instead. + + Matching is performed using the matching rules specified by + [RFC2459]. If more than one identity of a given type is present in + the certificate (e.g., more than one dNSName name, a match in any one + of the set is considered acceptable.) Names may contain the wildcard + character * which is considered to match any single domain name + component or component fragment. E.g., *.a.com matches foo.a.com but + not bar.foo.a.com. f*.com matches foo.com but not bar.com. + + In some cases, the URI is specified as an IP address rather than a + hostname. In this case, the iPAddress subjectAltName must be present + in the certificate and must exactly match the IP in the URI. + + */ + return verify_host_with_subject_alt_name(server_cert) || + verify_host_with_common_name(server_cert); + } + + bool + SSLClient::verify_host_with_subject_alt_name(X509* server_cert) const { + auto ret = false; + + auto type = GEN_DNS; + + struct in6_addr addr6; + struct in_addr addr; + size_t addr_len = 0; + +#ifndef __MINGW32__ + if (inet_pton(AF_INET6, host_.c_str(), &addr6)) { + type = GEN_IPADD; + addr_len = sizeof(struct in6_addr); + } + else if (inet_pton(AF_INET, host_.c_str(), &addr)) { + type = GEN_IPADD; + addr_len = sizeof(struct in_addr); + } +#endif + + auto alt_names = static_cast( + X509_get_ext_d2i(server_cert, NID_subject_alt_name, nullptr, nullptr)); + + if (alt_names) { + auto dsn_matched = false; + auto ip_mached = false; + + auto count = sk_GENERAL_NAME_num(alt_names); + + for (auto i = 0; i < count && !dsn_matched; i++) { + auto val = sk_GENERAL_NAME_value(alt_names, i); + if (val->type == type) { + auto name = (const char*)ASN1_STRING_get0_data(val->d.ia5); + auto name_len = (size_t)ASN1_STRING_length(val->d.ia5); + + if (strlen(name) == name_len) { + switch (type) { + case GEN_DNS: dsn_matched = check_host_name(name, name_len); break; + + case GEN_IPADD: + if (!memcmp(&addr6, name, addr_len) || + !memcmp(&addr, name, addr_len)) { + ip_mached = true; + } + break; + } + } + } + } + + if (dsn_matched || ip_mached) { ret = true; } + } + + GENERAL_NAMES_free((STACK_OF(GENERAL_NAME)*)alt_names); + + return ret; + } + + bool SSLClient::verify_host_with_common_name(X509* server_cert) const { + const auto subject_name = X509_get_subject_name(server_cert); + + if (subject_name != nullptr) { + char name[BUFSIZ]; + auto name_len = X509_NAME_get_text_by_NID(subject_name, NID_commonName, + name, sizeof(name)); + + if (name_len != -1) { + return check_host_name(name, static_cast(name_len)); + } + } + + return false; + } + + bool SSLClient::check_host_name(const char* pattern, + size_t pattern_len) const { + if (host_.size() == pattern_len && host_ == pattern) { return true; } + + // Wildcard match + // https://bugs.launchpad.net/ubuntu/+source/firefox-3.0/+bug/376484 + std::vector pattern_components; + detail::split(&pattern[0], &pattern[pattern_len], '.', + [&](const char* b, const char* e) { + pattern_components.emplace_back(std::string(b, e)); + }); + + if (host_components_.size() != pattern_components.size()) { return false; } + + auto itr = pattern_components.begin(); + for (const auto& h : host_components_) { + auto& p = *itr; + if (p != h && p != "*") { + auto partial_match = (p.size() > 0 && p[p.size() - 1] == '*' && + !p.compare(0, p.size() - 1, h)); + if (!partial_match) { return false; } + } + ++itr; + } + + return true; + } +#endif + +} // namespace httplib diff --git a/web/httplib.h b/web/httplib.h new file mode 100644 index 0000000..1861cd9 --- /dev/null +++ b/web/httplib.h @@ -0,0 +1,908 @@ +// +// httplib.h +// +// Copyright (c) 2020 Yuji Hirose. All rights reserved. +// MIT License +// + +#ifndef CPPHTTPLIB_HTTPLIB_H +#define CPPHTTPLIB_HTTPLIB_H + +/* + * Configuration + */ + +#define CPPHTTPLIB_ZLIB_SUPPORT +#define CPPHTTPLIB_OPENSSL_SUPPORT + +#ifndef CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND +#define CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND 5 +#endif + +#ifndef CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND +#define CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND 0 +#endif + +#ifndef CPPHTTPLIB_KEEPALIVE_MAX_COUNT +#define CPPHTTPLIB_KEEPALIVE_MAX_COUNT 5 +#endif + +#ifndef CPPHTTPLIB_READ_TIMEOUT_SECOND +#define CPPHTTPLIB_READ_TIMEOUT_SECOND 5 +#endif + +#ifndef CPPHTTPLIB_READ_TIMEOUT_USECOND +#define CPPHTTPLIB_READ_TIMEOUT_USECOND 0 +#endif + +#ifndef CPPHTTPLIB_REQUEST_URI_MAX_LENGTH +#define CPPHTTPLIB_REQUEST_URI_MAX_LENGTH 8192 +#endif + +#ifndef CPPHTTPLIB_REDIRECT_MAX_COUNT +#define CPPHTTPLIB_REDIRECT_MAX_COUNT 20 +#endif + +#ifndef CPPHTTPLIB_PAYLOAD_MAX_LENGTH +#define CPPHTTPLIB_PAYLOAD_MAX_LENGTH (std::numeric_limits::max()) +#endif + +#ifndef CPPHTTPLIB_RECV_BUFSIZ +#define CPPHTTPLIB_RECV_BUFSIZ size_t(8192u) +#endif + +#ifndef CPPHTTPLIB_THREAD_POOL_COUNT +#define CPPHTTPLIB_THREAD_POOL_COUNT 64 +// (std::max(1u, std::thread::hardware_concurrency() - 1)) +#endif + + /* + * Headers + */ + +#ifdef _WIN32 +#ifndef _CRT_SECURE_NO_WARNINGS +#define _CRT_SECURE_NO_WARNINGS +#endif //_CRT_SECURE_NO_WARNINGS + +#ifndef _CRT_NONSTDC_NO_DEPRECATE +#define _CRT_NONSTDC_NO_DEPRECATE +#endif //_CRT_NONSTDC_NO_DEPRECATE + +#if defined(_MSC_VER) +#ifdef _WIN64 +using ssize_t = __int64; +#else +using ssize_t = int; +#endif + +#if _MSC_VER < 1900 +#define snprintf _snprintf_s +#endif +#endif // _MSC_VER + +#ifndef S_ISREG +#define S_ISREG(m) (((m)&S_IFREG) == S_IFREG) +#endif // S_ISREG + +#ifndef S_ISDIR +#define S_ISDIR(m) (((m)&S_IFDIR) == S_IFDIR) +#endif // S_ISDIR + +#ifndef NOMINMAX +#define NOMINMAX +#endif // NOMINMAX + +#include +#include +#include + +#ifndef WSA_FLAG_NO_HANDLE_INHERIT +#define WSA_FLAG_NO_HANDLE_INHERIT 0x80 +#endif + +#ifdef _MSC_VER +#pragma comment(lib, "ws2_32.lib") +#endif + +#ifndef strcasecmp +#define strcasecmp _stricmp +#endif // strcasecmp + +using socket_t = SOCKET; +#ifdef CPPHTTPLIB_USE_POLL +#define poll(fds, nfds, timeout) WSAPoll(fds, nfds, timeout) +#endif + +#else // not _WIN32 + +#include +#include +#include +#include +#include +#ifdef CPPHTTPLIB_USE_POLL +#include +#endif +#include +#include +#include +#include +#include + +using socket_t = int; +#define INVALID_SOCKET (-1) +#endif //_WIN32 + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +#include +#include +#include +#include + +#include +#include + +// #if OPENSSL_VERSION_NUMBER < 0x1010100fL +// #error Sorry, OpenSSL versions prior to 1.1.1 are not supported +// #endif + +#if OPENSSL_VERSION_NUMBER < 0x10100000L +#include +inline const unsigned char* ASN1_STRING_get0_data(const ASN1_STRING* asn1) { + return M_ASN1_STRING_data(asn1); +} +#endif +#endif + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT +#include +#endif + +/* + * Declaration + */ +namespace httplib { + + namespace detail { + + struct ci { + bool operator()(const std::string& s1, const std::string& s2) const { + return std::lexicographical_compare( + s1.begin(), s1.end(), s2.begin(), s2.end(), + [](char c1, char c2) { return ::tolower(c1) < ::tolower(c2); }); + } + }; + + } // namespace detail + + using Headers = std::multimap; + + using Params = std::multimap; + using Match = std::smatch; + + using Progress = std::function; + + struct Response; + using ResponseHandler = std::function; + + struct MultipartFormData { + std::string name; + std::string content; + std::string filename; + std::string content_type; + }; + using MultipartFormDataItems = std::vector; + using MultipartFormDataMap = std::multimap; + + class DataSink { + public: + DataSink() = default; + DataSink(const DataSink&) = delete; + DataSink& operator=(const DataSink&) = delete; + DataSink(DataSink&&) = delete; + DataSink& operator=(DataSink&&) = delete; + + std::function write; + std::function done; + std::function is_writable; + }; + + using ContentProvider = + std::function; + + using ContentReceiver = + std::function; + + using MultipartContentHeader = + std::function; + + class ContentReader { + public: + using Reader = std::function; + using MultipartReader = std::function; + + ContentReader(Reader reader, MultipartReader muitlpart_reader) + : reader_(reader), muitlpart_reader_(muitlpart_reader) {} + + bool operator()(MultipartContentHeader header, + ContentReceiver receiver) const { + return muitlpart_reader_(header, receiver); + } + + bool operator()(ContentReceiver receiver) const { return reader_(receiver); } + + Reader reader_; + MultipartReader muitlpart_reader_; + }; + + using Range = std::pair; + using Ranges = std::vector; + + struct Request { + std::string method; + std::string path; + Headers headers; + std::string body; + + // for server + std::string version; + std::string target; + Params params; + MultipartFormDataMap files; + Ranges ranges; + Match matches; + + // for client + size_t redirect_count = CPPHTTPLIB_REDIRECT_MAX_COUNT; + ResponseHandler response_handler; + ContentReceiver content_receiver; + Progress progress; + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + const SSL* ssl; +#endif + + bool has_header(const char* key) const; + std::string get_header_value(const char* key, size_t id = 0) const; + size_t get_header_value_count(const char* key) const; + void set_header(const char* key, const char* val); + void set_header(const char* key, const std::string& val); + + bool has_param(const char* key) const; + std::string get_param_value(const char* key, size_t id = 0) const; + size_t get_param_value_count(const char* key) const; + + bool is_multipart_form_data() const; + + bool has_file(const char* key) const; + MultipartFormData get_file_value(const char* key) const; + + // private members... + size_t content_length; + ContentProvider content_provider; + }; + + struct Response { + std::string version; + int status = -1; + Headers headers; + std::string body; + + bool has_header(const char* key) const; + std::string get_header_value(const char* key, size_t id = 0) const; + size_t get_header_value_count(const char* key) const; + void set_header(const char* key, const char* val); + void set_header(const char* key, const std::string& val); + + void set_redirect(const char* url); + void set_content(const char* s, size_t n, const char* content_type); + void set_content(const std::string& s, const char* content_type); + + void set_content_provider( + size_t length, + std::function + provider, + std::function resource_releaser = [] {}); + + void set_chunked_content_provider( + std::function provider, + std::function resource_releaser = [] {}); + + Response() = default; + Response(const Response&) = default; + Response& operator=(const Response&) = default; + Response(Response&&) = default; + Response& operator=(Response&&) = default; + ~Response() { + if (content_provider_resource_releaser) { + content_provider_resource_releaser(); + } + } + + // private members... + size_t content_length = 0; + ContentProvider content_provider; + std::function content_provider_resource_releaser; + }; + + class Stream { + public: + virtual ~Stream() = default; + + virtual bool is_readable() const = 0; + virtual bool is_writable() const = 0; + + virtual ssize_t read(char* ptr, size_t size) = 0; + virtual ssize_t write(const char* ptr, size_t size) = 0; + virtual std::string get_remote_addr() const = 0; + + template + ssize_t write_format(const char* fmt, const Args&... args); + ssize_t write(const char* ptr); + ssize_t write(const std::string& s); + }; + + class TaskQueue { + public: + TaskQueue() = default; + virtual ~TaskQueue() = default; + virtual void enqueue(std::function fn) = 0; + virtual void shutdown() = 0; + }; + + class ThreadPool : public TaskQueue { + public: + explicit ThreadPool(size_t n) : shutdown_(false) { + while (n) { + threads_.emplace_back(worker(*this)); + n--; + } + } + + ThreadPool(const ThreadPool&) = delete; + ~ThreadPool() override = default; + + void enqueue(std::function fn) override { + std::unique_lock lock(mutex_); + jobs_.push_back(fn); + cond_.notify_one(); + } + + void shutdown() override { + // Stop all worker threads... + { + std::unique_lock lock(mutex_); + shutdown_ = true; + } + + cond_.notify_all(); + + // Join... + for (auto& t : threads_) { + t.join(); + } + } + + private: + struct worker { + explicit worker(ThreadPool& pool) : pool_(pool) {} + + void operator()() { + for (;;) { + std::function fn; + { + std::unique_lock lock(pool_.mutex_); + + pool_.cond_.wait( + lock, [&] { return !pool_.jobs_.empty() || pool_.shutdown_; }); + + if (pool_.shutdown_ && pool_.jobs_.empty()) { break; } + + fn = pool_.jobs_.front(); + pool_.jobs_.pop_front(); + } + + assert(true == static_cast(fn)); + fn(); + } + } + + ThreadPool& pool_; + }; + friend struct worker; + + std::vector threads_; + std::list> jobs_; + + bool shutdown_; + + std::condition_variable cond_; + std::mutex mutex_; + }; + + using Logger = std::function; + + class Server { + public: + using Handler = std::function; + using HandlerWithContentReader = std::function; + using Expect100ContinueHandler = + std::function; + + Server(); + + virtual ~Server(); + + virtual bool is_valid() const; + + Server& Get(const char* pattern, Handler handler); + Server& Post(const char* pattern, Handler handler); + Server& Post(const char* pattern, HandlerWithContentReader handler); + Server& Put(const char* pattern, Handler handler); + Server& Put(const char* pattern, HandlerWithContentReader handler); + Server& Patch(const char* pattern, Handler handler); + Server& Patch(const char* pattern, HandlerWithContentReader handler); + Server& Delete(const char* pattern, Handler handler); + Server& Options(const char* pattern, Handler handler); + + [[deprecated]] bool set_base_dir(const char* dir, + const char* mount_point = nullptr); + bool set_mount_point(const char* mount_point, const char* dir); + bool remove_mount_point(const char* mount_point); + void set_file_extension_and_mimetype_mapping(const char* ext, + const char* mime); + void set_file_request_handler(Handler handler); + + void set_error_handler(Handler handler); + void set_logger(Logger logger); + + void set_expect_100_continue_handler(Expect100ContinueHandler handler); + + void set_keep_alive_max_count(size_t count); + void set_read_timeout(time_t sec, time_t usec); + void set_payload_max_length(size_t length); + + bool bind_to_port(const char* host, int port, int socket_flags = 0); + int bind_to_any_port(const char* host, int socket_flags = 0); + bool listen_after_bind(); + + bool listen(const char* host, int port, int socket_flags = 0); + + bool is_running() const; + void stop(); + + std::function new_task_queue; + + protected: + bool process_request(Stream& strm, bool last_connection, + bool& connection_close, + const std::function& setup_request); + + size_t keep_alive_max_count_; + time_t read_timeout_sec_; + time_t read_timeout_usec_; + size_t payload_max_length_; + + private: + using Handlers = std::vector>; + using HandlersForContentReader = + std::vector>; + + socket_t create_server_socket(const char* host, int port, + int socket_flags) const; + int bind_internal(const char* host, int port, int socket_flags); + bool listen_internal(); + + bool routing(Request& req, Response& res, Stream& strm, bool last_connection); + bool handle_file_request(Request& req, Response& res, bool head = false); + bool dispatch_request(Request& req, Response& res, Handlers& handlers); + bool dispatch_request_for_content_reader(Request& req, Response& res, + ContentReader content_reader, + HandlersForContentReader& handlers); + + bool parse_request_line(const char* s, Request& req); + bool write_response(Stream& strm, bool last_connection, const Request& req, + Response& res); + bool write_content_with_provider(Stream& strm, const Request& req, + Response& res, const std::string& boundary, + const std::string& content_type); + bool read_content(Stream& strm, bool last_connection, Request& req, + Response& res); + bool read_content_with_content_receiver( + Stream& strm, bool last_connection, Request& req, Response& res, + ContentReceiver receiver, MultipartContentHeader multipart_header, + ContentReceiver multipart_receiver); + bool read_content_core(Stream& strm, bool last_connection, Request& req, + Response& res, ContentReceiver receiver, + MultipartContentHeader mulitpart_header, + ContentReceiver multipart_receiver); + + virtual bool process_and_close_socket(socket_t sock); + + std::atomic is_running_; + std::atomic svr_sock_; + std::vector> base_dirs_; + std::map file_extension_and_mimetype_map_; + Handler file_request_handler_; + Handlers get_handlers_; + Handlers post_handlers_; + HandlersForContentReader post_handlers_for_content_reader_; + Handlers put_handlers_; + HandlersForContentReader put_handlers_for_content_reader_; + Handlers patch_handlers_; + HandlersForContentReader patch_handlers_for_content_reader_; + Handlers delete_handlers_; + Handlers options_handlers_; + Handler error_handler_; + Logger logger_; + Expect100ContinueHandler expect_100_continue_handler_; + }; + + class Client { + public: + explicit Client(const std::string& host, int port = 80, + const std::string& client_cert_path = std::string(), + const std::string& client_key_path = std::string()); + + virtual ~Client(); + + virtual bool is_valid() const; + + std::shared_ptr Get(const char* path); + + std::shared_ptr Get(const char* path, const Headers& headers); + + std::shared_ptr Get(const char* path, Progress progress); + + std::shared_ptr Get(const char* path, const Headers& headers, + Progress progress); + + std::shared_ptr Get(const char* path, + ContentReceiver content_receiver); + + std::shared_ptr Get(const char* path, const Headers& headers, + ContentReceiver content_receiver); + + std::shared_ptr + Get(const char* path, ContentReceiver content_receiver, Progress progress); + + std::shared_ptr Get(const char* path, const Headers& headers, + ContentReceiver content_receiver, + Progress progress); + + std::shared_ptr Get(const char* path, const Headers& headers, + ResponseHandler response_handler, + ContentReceiver content_receiver); + + std::shared_ptr Get(const char* path, const Headers& headers, + ResponseHandler response_handler, + ContentReceiver content_receiver, + Progress progress); + + std::shared_ptr Head(const char* path); + + std::shared_ptr Head(const char* path, const Headers& headers); + + std::shared_ptr Post(const char* path, const std::string& body, + const char* content_type); + + std::shared_ptr Post(const char* path, const Headers& headers, + const std::string& body, + const char* content_type); + + std::shared_ptr Post(const char* path, size_t content_length, + ContentProvider content_provider, + const char* content_type); + + std::shared_ptr Post(const char* path, const Headers& headers, + size_t content_length, + ContentProvider content_provider, + const char* content_type); + + std::shared_ptr Post(const char* path, const Params& params); + + std::shared_ptr Post(const char* path, const Headers& headers, + const Params& params); + + std::shared_ptr Post(const char* path, + const MultipartFormDataItems& items); + + std::shared_ptr Post(const char* path, const Headers& headers, + const MultipartFormDataItems& items); + + std::shared_ptr Put(const char* path, const std::string& body, + const char* content_type); + + std::shared_ptr Put(const char* path, const Headers& headers, + const std::string& body, + const char* content_type); + + std::shared_ptr Put(const char* path, size_t content_length, + ContentProvider content_provider, + const char* content_type); + + std::shared_ptr Put(const char* path, const Headers& headers, + size_t content_length, + ContentProvider content_provider, + const char* content_type); + + std::shared_ptr Put(const char* path, const Params& params); + + std::shared_ptr Put(const char* path, const Headers& headers, + const Params& params); + + std::shared_ptr Patch(const char* path, const std::string& body, + const char* content_type); + + std::shared_ptr Patch(const char* path, const Headers& headers, + const std::string& body, + const char* content_type); + + std::shared_ptr Patch(const char* path, size_t content_length, + ContentProvider content_provider, + const char* content_type); + + std::shared_ptr Patch(const char* path, const Headers& headers, + size_t content_length, + ContentProvider content_provider, + const char* content_type); + + std::shared_ptr Delete(const char* path); + + std::shared_ptr Delete(const char* path, const std::string& body, + const char* content_type); + + std::shared_ptr Delete(const char* path, const Headers& headers); + + std::shared_ptr Delete(const char* path, const Headers& headers, + const std::string& body, + const char* content_type); + + std::shared_ptr Options(const char* path); + + std::shared_ptr Options(const char* path, const Headers& headers); + + bool send(const Request& req, Response& res); + + bool send(const std::vector& requests, + std::vector& responses); + + void set_timeout_sec(time_t timeout_sec); + + void set_read_timeout(time_t sec, time_t usec); + + void set_keep_alive_max_count(size_t count); + + void set_basic_auth(const char* username, const char* password); + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void set_digest_auth(const char* username, const char* password); +#endif + + void set_follow_location(bool on); + + void set_compress(bool on); + + void set_interface(const char* intf); + + void set_proxy(const char* host, int port); + + void set_proxy_basic_auth(const char* username, const char* password); + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void set_proxy_digest_auth(const char* username, const char* password); +#endif + + void set_logger(Logger logger); + + protected: + bool process_request(Stream& strm, const Request& req, Response& res, + bool last_connection, bool& connection_close); + + const std::string host_; + const int port_; + const std::string host_and_port_; + + // Settings + std::string client_cert_path_; + std::string client_key_path_; + + time_t timeout_sec_ = 300; + time_t read_timeout_sec_ = CPPHTTPLIB_READ_TIMEOUT_SECOND; + time_t read_timeout_usec_ = CPPHTTPLIB_READ_TIMEOUT_USECOND; + + size_t keep_alive_max_count_ = CPPHTTPLIB_KEEPALIVE_MAX_COUNT; + + std::string basic_auth_username_; + std::string basic_auth_password_; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + std::string digest_auth_username_; + std::string digest_auth_password_; +#endif + + bool follow_location_ = false; + + bool compress_ = false; + + std::string interface_; + + std::string proxy_host_; + int proxy_port_; + + std::string proxy_basic_auth_username_; + std::string proxy_basic_auth_password_; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + std::string proxy_digest_auth_username_; + std::string proxy_digest_auth_password_; +#endif + + Logger logger_; + + void copy_settings(const Client& rhs) { + client_cert_path_ = rhs.client_cert_path_; + client_key_path_ = rhs.client_key_path_; + timeout_sec_ = rhs.timeout_sec_; + read_timeout_sec_ = rhs.read_timeout_sec_; + read_timeout_usec_ = rhs.read_timeout_usec_; + keep_alive_max_count_ = rhs.keep_alive_max_count_; + basic_auth_username_ = rhs.basic_auth_username_; + basic_auth_password_ = rhs.basic_auth_password_; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + digest_auth_username_ = rhs.digest_auth_username_; + digest_auth_password_ = rhs.digest_auth_password_; +#endif + follow_location_ = rhs.follow_location_; + compress_ = rhs.compress_; + interface_ = rhs.interface_; + proxy_host_ = rhs.proxy_host_; + proxy_port_ = rhs.proxy_port_; + proxy_basic_auth_username_ = rhs.proxy_basic_auth_username_; + proxy_basic_auth_password_ = rhs.proxy_basic_auth_password_; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + proxy_digest_auth_username_ = rhs.proxy_digest_auth_username_; + proxy_digest_auth_password_ = rhs.proxy_digest_auth_password_; +#endif + logger_ = rhs.logger_; + } + + private: + socket_t create_client_socket() const; + bool read_response_line(Stream& strm, Response& res); + bool write_request(Stream& strm, const Request& req, bool last_connection); + bool redirect(const Request& req, Response& res); + bool handle_request(Stream& strm, const Request& req, Response& res, + bool last_connection, bool& connection_close); +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + bool connect(socket_t sock, Response& res, bool& error); +#endif + + std::shared_ptr send_with_content_provider( + const char* method, const char* path, const Headers& headers, + const std::string& body, size_t content_length, + ContentProvider content_provider, const char* content_type); + + virtual bool process_and_close_socket( + socket_t sock, size_t request_count, + std::function + callback); + + virtual bool is_ssl() const; + }; + + inline void Get(std::vector& requests, const char* path, + const Headers& headers) { + Request req; + req.method = "GET"; + req.path = path; + req.headers = headers; + requests.emplace_back(std::move(req)); + } + + inline void Get(std::vector& requests, const char* path) { + Get(requests, path, Headers()); + } + + inline void Post(std::vector& requests, const char* path, + const Headers& headers, const std::string& body, + const char* content_type) { + Request req; + req.method = "POST"; + req.path = path; + req.headers = headers; + req.headers.emplace("Content-Type", content_type); + req.body = body; + requests.emplace_back(std::move(req)); + } + + inline void Post(std::vector& requests, const char* path, + const std::string& body, const char* content_type) { + Post(requests, path, Headers(), body, content_type); + } + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + class SSLServer : public Server { + public: + SSLServer(const char* cert_path, const char* private_key_path, + const char* client_ca_cert_file_path = nullptr, + const char* client_ca_cert_dir_path = nullptr); + + ~SSLServer() override; + + bool is_valid() const override; + + private: + bool process_and_close_socket(socket_t sock) override; + + SSL_CTX* ctx_; + std::mutex ctx_mutex_; + }; + + class SSLClient : public Client { + public: + SSLClient(const std::string& host, int port = 443, + const std::string& client_cert_path = std::string(), + const std::string& client_key_path = std::string()); + + ~SSLClient() override; + + bool is_valid() const override; + + void set_ca_cert_path(const char* ca_ceert_file_path, + const char* ca_cert_dir_path = nullptr); + + void enable_server_certificate_verification(bool enabled); + + long get_openssl_verify_result() const; + + SSL_CTX* ssl_context() const noexcept; + + private: + bool process_and_close_socket( + socket_t sock, size_t request_count, + std::function + callback) override; + bool is_ssl() const override; + + bool verify_host(X509* server_cert) const; + bool verify_host_with_subject_alt_name(X509* server_cert) const; + bool verify_host_with_common_name(X509* server_cert) const; + bool check_host_name(const char* pattern, size_t pattern_len) const; + + SSL_CTX* ctx_; + std::mutex ctx_mutex_; + std::vector host_components_; + + std::string ca_cert_file_path_; + std::string ca_cert_dir_path_; + bool server_certificate_verification_ = false; + long verify_result_ = 0; + }; +#endif + + +} // namespace httplib + +#endif // CPPHTTPLIB_HTTPLIB_H \ No newline at end of file diff --git a/web/manifest.cc b/web/manifest.cc new file mode 100644 index 0000000..0fe7b95 --- /dev/null +++ b/web/manifest.cc @@ -0,0 +1,564 @@ +#include "manifest.h" +#include "http.h" +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace fs = std::filesystem; + +typedef struct _MANIFEST_CHUNK { + char Guid[16]; + uint64_t Hash; + char ShaHash[20]; + uint8_t Group; + uint64_t Size; +} MANIFEST_CHUNK; + +typedef struct _MANIFEST_CHUNK_PART { + std::shared_ptr Chunk; + uint32_t Offset; + uint32_t Size; +} MANIFEST_CHUNK_PART; + +typedef std::unique_ptr MANIFEST_CHUNK_PARTS; +typedef struct _MANIFEST_FILE { + char FileName[128]; + char ShaHash[20]; + MANIFEST_CHUNK_PARTS ChunkParts; + uint32_t ChunkPartCount; +} MANIFEST_FILE; + +enum class EFeatureLevel : int32_t +{ + // The original version. + Original = 0, + // Support for custom fields. + CustomFields, + // Started storing the version number. + StartStoringVersion, + // Made after data files where renamed to include the hash value, these chunks now go to ChunksV2. + DataFileRenames, + // Manifest stores whether build was constructed with chunk or file data. + StoresIfChunkOrFileData, + // Manifest stores group number for each chunk/file data for reference so that external readers don't need to know how to calculate them. + StoresDataGroupNumbers, + // Added support for chunk compression, these chunks now go to ChunksV3. NB: Not File Data Compression yet. + ChunkCompressionSupport, + // Manifest stores product prerequisites info. + StoresPrerequisitesInfo, + // Manifest stores chunk download sizes. + StoresChunkFileSizes, + // Manifest can optionally be stored using UObject serialization and compressed. + StoredAsCompressedUClass, + // These two features were removed and never used. + UNUSED_0, + UNUSED_1, + // Manifest stores chunk data SHA1 hash to use in place of data compare, for faster generation. + StoresChunkDataShaHashes, + // Manifest stores Prerequisite Ids. + StoresPrerequisiteIds, + // The first minimal binary format was added. UObject classes will no longer be saved out when binary selected. + StoredAsBinaryData, + // Temporary level where manifest can reference chunks with dynamic window size, but did not serialize them. Chunks from here onwards are stored in ChunksV4. + VariableSizeChunksWithoutWindowSizeChunkInfo, + // Manifest can reference chunks with dynamic window size, and also serializes them. + VariableSizeChunks, + // Manifest stores a unique build id for exact matching of build data. + StoresUniqueBuildId, + + // !! Always after the latest version entry, signifies the latest version plus 1 to allow the following Latest alias. + LatestPlusOne, + // An alias for the actual latest version value. + Latest = (LatestPlusOne - 1), + // An alias to provide the latest version of a manifest supported by file data (nochunks). + LatestNoChunks = StoresChunkFileSizes, + // An alias to provide the latest version of a manifest supported by a json serialized format. + LatestJson = StoresPrerequisiteIds, + // An alias to provide the first available version of optimised delta manifest saving. + FirstOptimisedDelta = StoresUniqueBuildId, + + // JSON manifests were stored with a version of 255 during a certain CL range due to a bug. + // We will treat this as being StoresChunkFileSizes in code. + BrokenJsonVersion = 255, + // This is for UObject default, so that we always serialize it. + Invalid = -1 +}; + +typedef std::unique_ptr MANIFEST_FILE_LIST; +typedef std::unique_ptr[]> MANIFEST_CHUNK_LIST; +auto hash = [](const char* n) { return (*((uint64_t*)n)) ^ (*(((uint64_t*)n) + 1)); }; +auto equal = [](const char* a, const char* b) {return !memcmp(a, b, 16); }; +typedef std::unordered_map MANIFEST_CHUNK_LOOKUP; +typedef struct _MANIFEST { + EFeatureLevel FeatureLevel; + bool bIsFileData; + uint32_t AppID; + std::string AppName; + std::string BuildVersion; + std::string LaunchExe; + std::string LaunchCommand; + //std::set PrereqIds; + //std::string PrereqName; + //std::string PrereqPath; + //std::string PrereqArgs; + MANIFEST_FILE_LIST FileManifestList; + uint32_t FileCount; + MANIFEST_CHUNK_LIST ChunkManifestList; + uint32_t ChunkCount; + + std::string CloudDirHost; + std::string CloudDirPath; +} MANIFEST; + +inline const uint8_t GetByteValue(const char Char) +{ + if (Char >= '0' && Char <= '9') + { + return Char - '0'; + } + else if (Char >= 'A' && Char <= 'F') + { + return (Char - 'A') + 10; + } + return (Char - 'a') + 10; +} + +inline int HexToBytes(const char* hex, char* output) { + int NumBytes = 0; + while (*hex) + { + output[NumBytes] = GetByteValue(*hex++) << 4; + output[NumBytes] += GetByteValue(*hex++); + ++NumBytes; + } + return NumBytes; +} + +inline void urlencode(const const char* s, std::ostringstream& e) +{ + static const char lookup[] = "0123456789abcdef"; + for (int i = 0, ix = strlen(s); i < ix; i++) + { + const char& c = s[i]; + if ((48 <= c && c <= 57) ||//0-9 + (65 <= c && c <= 90) ||//abc...xyz + (97 <= c && c <= 122) || //ABC...XYZ + (c == '-' || c == '_' || c == '.' || c == '~') + ) + { + e << c; + } + else + { + e << '%'; + e << lookup[(c & 0xF0) >> 4]; + e << lookup[(c & 0x0F)]; + } + } +} + +inline void GetDownloadFile(rapidjson::Value& manifest, std::string* id) { + auto uriStr = std::string(manifest["uri"].GetString()); + *id = std::string(uriStr.begin() + uriStr.find_last_of('/') + 1, uriStr.end()); +} + +inline void GetDownloadUrl(rapidjson::Value& manifest, std::string* host, std::string* uri) { + rapidjson::Value& uri_val = manifest["uri"]; + Uri url_v = Uri::Parse(uri_val.GetString()); + *host = url_v.Host; + if (!manifest.HasMember("queryParams")) { + *uri = url_v.Path; + } + else { + rapidjson::Value& queryParams = manifest["queryParams"]; + std::ostringstream oss; + oss << url_v.Path << "?"; + for (auto& itr : queryParams.GetArray()) { + urlencode(itr["name"].GetString(), oss); + oss << "="; + urlencode(itr["value"].GetString(), oss); + oss << "&"; + } + oss.seekp(-1, std::ios_base::end); // remove last & + oss << '\0'; + *uri = oss.str(); + } +} + +inline int random(int min, int max) //range : [min, max) +{ + static bool first = true; + if (first) + { + srand(time(NULL)); //seeding for the first time only! + first = false; + } + return min + rand() % ((max + 1) - min); +} + +inline const void HashToBytes(const char* data, char* output) { + int size = strlen(data) / 3; + char buf[4]; + buf[3] = '\0'; + for (int i = 0; i < size; i++) + { + buf[0] = data[i * 3]; + buf[1] = data[i * 3 + 1]; + buf[2] = data[i * 3 + 2]; + output[i] = atoi(buf); + } +} + +bool ManifestGrab(const char* elementsResponse, fs::path CacheFolder, MANIFEST** PManifest) { + std::string host, url; + fs::path cachePath; + { + rapidjson::Document elements; + elements.Parse(elementsResponse); + if (elements.HasParseError()) { + printf("%d %zu\n", elements.GetParseError(), elements.GetErrorOffset()); + } + rapidjson::Value& v = elements["elements"].GetArray()[0]; + + //rapidjson::Value& hash = v["hash"]; (hash is unused atm) + //char* ManifestHash = new char[hash.GetStringLength() / 2]; + //HexToBytes(hash.GetString(), ManifestHash); + + rapidjson::Value& manifest = v["manifests"][random(0, v["manifests"].GetArray().Size() - 1)]; + GetDownloadUrl(manifest, &host, &url); + std::string id; + GetDownloadFile(manifest, &id); + if (!CacheFolder.empty()) { + cachePath = CacheFolder / id; + } + } + httplib::Client c(host); + + rapidjson::Document manifestDoc; + printf("getting manifest\n"); + if (!CacheFolder.empty() && fs::status(cachePath).type() == fs::file_type::regular) { + auto fp = fopen(cachePath.string().c_str(), "rb"); + fseek(fp, 0, SEEK_END); + long manifestSize = ftell(fp); + rewind(fp); + auto manifestStr = new char[manifestSize + 1]; + fread(manifestStr, 1, manifestSize, fp); + fclose(fp); + manifestStr[manifestSize] = '\0'; + manifestDoc.Parse(manifestStr); + delete[] manifestStr; + } + else { + bool manifestStrReserved = false; + std::vector manifestStr; + manifestStr.reserve(8192); + c.Get(url.c_str(), + [&](const char* data, uint64_t data_length) { + manifestStr.insert(manifestStr.end(), data, data + data_length); + return true; + }, + [&](uint64_t len, uint64_t total) { + if (!manifestStrReserved) { + manifestStr.reserve(total); + printf("reserved for %llu\n", total); + manifestStrReserved = true; + } + static int n = 0; + if (!(n++ % 400)) { + printf("\r%lld / %lld bytes => %.1f%% complete", + len, total, + float(len * 100) / total); + } + return true; + }); + + if (!CacheFolder.empty()) { + auto fp = fopen(cachePath.string().c_str(), "wb"); + fwrite(manifestStr.data(), 1, manifestStr.size(), fp); + fclose(fp); + } + + manifestStr.push_back('\0'); + manifestDoc.Parse(manifestStr.data()); + } + printf("\n"); + printf("parsing\n"); + printf("parsed\n"); + + if (manifestDoc.HasParseError()) { + printf("JSON Parse Error %d @ %zu\n", manifestDoc.GetParseError(), manifestDoc.GetErrorOffset()); + } + + MANIFEST* Manifest = new MANIFEST(); + HashToBytes(manifestDoc["ManifestFileVersion"].GetString(), (char*)&Manifest->FeatureLevel); + Manifest->bIsFileData = manifestDoc["bIsFileData"].GetBool(); + HashToBytes(manifestDoc["AppID"].GetString(), (char*)&Manifest->AppID); + Manifest->AppName = manifestDoc["AppNameString"].GetString(); + Manifest->BuildVersion = manifestDoc["BuildVersionString"].GetString(); + Manifest->LaunchExe = manifestDoc["LaunchExeString"].GetString(); + Manifest->LaunchCommand = manifestDoc["LaunchCommand"].GetString(); + + Manifest->CloudDirHost = host; +#define CHUNK_DIR(dir) "/Chunks" dir "/" + const char* ChunksDir; + if (Manifest->FeatureLevel < EFeatureLevel::DataFileRenames) { + ChunksDir = CHUNK_DIR(""); + } + else if (Manifest->FeatureLevel < EFeatureLevel::ChunkCompressionSupport) { + ChunksDir = CHUNK_DIR("V2"); + } + else if (Manifest->FeatureLevel < EFeatureLevel::VariableSizeChunksWithoutWindowSizeChunkInfo) { + ChunksDir = CHUNK_DIR("V3"); + } + else { + ChunksDir = CHUNK_DIR("V4"); + } +#undef CHUNK_DIR + Manifest->CloudDirPath = url.substr(0, url.find_last_of('/')) + ChunksDir; + + MANIFEST_CHUNK_LOOKUP ChunkManifestLookup; // used to speed up lookups instead of doing a linear search over everything + { + rapidjson::Value& HashList = manifestDoc["ChunkHashList"]; + rapidjson::Value& ShaList = manifestDoc["ChunkShaList"]; + rapidjson::Value& GroupList = manifestDoc["DataGroupList"]; + rapidjson::Value& SizeList = manifestDoc["ChunkFilesizeList"]; + + Manifest->ChunkCount = HashList.MemberCount(); + Manifest->ChunkManifestList = std::make_unique[]>(Manifest->ChunkCount); + ChunkManifestLookup.reserve(Manifest->ChunkCount); + + printf("%d chunks\n", Manifest->ChunkCount); + int i = 0; + for (rapidjson::Value::ConstMemberIterator hashItr = HashList.MemberBegin(), shaItr = ShaList.MemberBegin(), groupItr = GroupList.MemberBegin(), sizeItr = SizeList.MemberBegin(); + i != Manifest->ChunkCount; ++i, ++hashItr, ++shaItr, ++groupItr, ++sizeItr) + { + auto chunk = std::make_shared(); + HexToBytes(hashItr->name.GetString(), chunk->Guid); + HashToBytes(hashItr->value.GetString(), (char*)&chunk->Hash); + HashToBytes(sizeItr->value.GetString(), (char*)&chunk->Size); + HexToBytes(shaItr->value.GetString(), (char*)&chunk->ShaHash); + chunk->Group = atoi(groupItr->value.GetString()); + + Manifest->ChunkManifestList[i] = chunk; + ChunkManifestLookup[chunk->Guid] = i; + } + } + + { + rapidjson::Value& FileList = manifestDoc["FileManifestList"]; + Manifest->FileCount = FileList.Size(); + Manifest->FileManifestList = std::make_unique(Manifest->FileCount); + + int i = 0; + for (auto& fileManifest : FileList.GetArray()) { + auto& file = Manifest->FileManifestList[i++]; + strcpy(file.FileName, fileManifest["Filename"].GetString()); + HashToBytes(fileManifest["FileHash"].GetString(), (char*)&file.ShaHash); + file.ChunkPartCount = fileManifest["FileChunkParts"].Size(); + file.ChunkParts = std::make_unique(file.ChunkPartCount); + + int j = 0; + for (auto& fileChunk : fileManifest["FileChunkParts"].GetArray()) { + MANIFEST_CHUNK_PART part; + char guidBuffer[16]; + HexToBytes(fileChunk["Guid"].GetString(), guidBuffer); + part.Chunk = Manifest->ChunkManifestList[ChunkManifestLookup[guidBuffer]]; + HashToBytes(fileChunk["Offset"].GetString(), (char*)&part.Offset); + HashToBytes(fileChunk["Size"].GetString(), (char*)&part.Size); + file.ChunkParts[j++] = part; + } + } + } + + *PManifest = Manifest; + return true; +} + +uint64_t ManifestDownloadSize(MANIFEST* Manifest) { + return std::accumulate(Manifest->ChunkManifestList.get(), Manifest->ChunkManifestList.get() + Manifest->ChunkCount, 0ull, + [](uint64_t sum, const std::shared_ptr& curr) { + return sum + curr->Size; + }); +} + +uint64_t ManifestInstallSize(MANIFEST* Manifest) { + return std::accumulate(Manifest->FileManifestList.get(), Manifest->FileManifestList.get() + Manifest->FileCount, 0ull, + [](uint64_t sum, MANIFEST_FILE& file) { + return std::accumulate(file.ChunkParts.get(), file.ChunkParts.get() + file.ChunkPartCount, sum, + [](uint64_t sum, MANIFEST_CHUNK_PART& part) { + return sum + part.Size; + }); + }); +} + +void ManifestGetFiles(MANIFEST* Manifest, MANIFEST_FILE** PFileList, uint32_t* PFileCount, uint16_t* PStrideSize) { + *PFileList = Manifest->FileManifestList.get(); + *PFileCount = Manifest->FileCount; + *PStrideSize = sizeof(MANIFEST_FILE); +} + +void ManifestGetChunks(MANIFEST* Manifest, std::shared_ptr** PChunkList, uint32_t* PChunkCount) { + *PChunkList = Manifest->ChunkManifestList.get(); + *PChunkCount = Manifest->ChunkCount; +} + +MANIFEST_FILE* ManifestGetFile(MANIFEST* Manifest, const char* Filename) { + for (int i = 0; i < Manifest->FileCount; ++i) { + if (!strcmp(Filename, Manifest->FileManifestList[i].FileName)) { + return &Manifest->FileManifestList[i]; + } + } + return nullptr; +} + +void ManifestGetCloudDir(MANIFEST* Manifest, char* CloudDirHostBuffer, char* CloudDirPathBuffer) { + strcpy(CloudDirHostBuffer, Manifest->CloudDirHost.c_str()); + strcpy(CloudDirPathBuffer, Manifest->CloudDirPath.c_str()); +} + +void ManifestGetLaunchInfo(MANIFEST* Manifest, char* ExeBuffer, char* CommandBuffer) { + strcpy(ExeBuffer, Manifest->LaunchExe.c_str()); + strcpy(CommandBuffer, Manifest->LaunchCommand.c_str()); +} + +void ManifestDelete(MANIFEST* Manifest) { + delete Manifest; +} + +typedef struct _MANIFEST_AUTH { + std::string AccessToken; + time_t ExpiresAt; +} MANIFEST_AUTH; + +inline int ParseInt(const char* value) +{ + return std::strtol(value, nullptr, 10); +} + +void UpdateManifest(MANIFEST_AUTH* Auth) { + httplib::SSLClient client("account-public-service-prod03.ol.epicgames.com"); + static const httplib::Headers headers = { + { "Authorization", "basic MzhkYmZjMzE5NjAyNGQ1OTgwMzg2YTM3YjdjNzkyYmI6YTYyODBiODctZTQ1ZS00MDliLTk2ODEtOGYxNWViN2RiY2Y1" } + }; + static const httplib::Params params = { + { "grant_type", "client_credentials" } + }; + rapidjson::Document d; + d.Parse(client.Post("/account/api/oauth/token", headers, params)->body.c_str()); + + rapidjson::Value& token = d["access_token"]; + Auth->AccessToken = token.GetString(); + + rapidjson::Value& expires_at = d["expires_at"]; + auto expires_str = expires_at.GetString(); + constexpr const size_t expectedLength = sizeof("YYYY-MM-DDTHH:MM:SSZ") - 1; + static_assert(expectedLength == 20, "Unexpected ISO 8601 date/time length"); + + if (expires_at.GetStringLength() < expectedLength) + { + return; + } + + std::tm time = { 0 }; + time.tm_year = ParseInt(&expires_str[0]) - 1900; + time.tm_mon = ParseInt(&expires_str[5]) - 1; + time.tm_mday = ParseInt(&expires_str[8]); + time.tm_hour = ParseInt(&expires_str[11]); + time.tm_min = ParseInt(&expires_str[14]); + time.tm_sec = ParseInt(&expires_str[17]); + time.tm_isdst = 0; + const int millis = expires_at.GetStringLength() > 20 ? ParseInt(&expires_str[20]) : 0; + Auth->ExpiresAt = std::mktime(&time) * 1000 + millis; +} + +bool ManifestAuthGrab(MANIFEST_AUTH** PManifestAuth) { + MANIFEST_AUTH* Auth = new MANIFEST_AUTH; + UpdateManifest(Auth); + *PManifestAuth = Auth; + return true; +} + +void ManifestAuthDelete(MANIFEST_AUTH* ManifestAuth) { + delete ManifestAuth; +} + +bool ManifestAuthGetManifest(MANIFEST_AUTH* ManifestAuth, fs::path CachePath, MANIFEST** PManifest) { + if (ManifestAuth->ExpiresAt < time(NULL)) { + UpdateManifest(ManifestAuth); + } + + httplib::SSLClient client("launcher-public-service-prod-m.ol.epicgames.com"); + char* authHeader = new char[7 + ManifestAuth->AccessToken.size() + 1]; + sprintf(authHeader, "bearer %s", ManifestAuth->AccessToken.c_str()); + + httplib::Headers headers = { + { "Authorization", authHeader } + }; + ManifestGrab(client.Post("/launcher/api/public/assets/v2/platform/Windows/catalogItem/4fe75bbc5a674f4f9b356b5c90567da5/app/Fortnite/label/Live/", headers, "", "application/json")->body.c_str(), CachePath, PManifest); + delete[] authHeader; + return true; +} + +void ManifestFileGetName(MANIFEST_FILE* File, char* FilenameBuffer) { + strcpy(FilenameBuffer, File->FileName); +} + +void ManifestFileGetChunks(MANIFEST_FILE* File, MANIFEST_CHUNK_PART** PChunkPartList, uint32_t* PChunkPartCount, uint16_t* PStrideSize) { + *PChunkPartList = File->ChunkParts.get(); + *PChunkPartCount = File->ChunkPartCount; + *PStrideSize = sizeof(MANIFEST_CHUNK_PART); +} + +bool ManifestFileGetChunkIndex(MANIFEST_FILE* File, uint64_t Offset, uint32_t* ChunkIndex, uint32_t* ChunkOffset) { + for (int i = 0; i < File->ChunkPartCount; ++i) { + if (Offset < File->ChunkParts[i].Size) { + *ChunkIndex = i; + *ChunkOffset = Offset; + return true; + } + Offset -= File->ChunkParts[i].Size; + } + return false; +} + +uint64_t ManifestFileGetFileSize(MANIFEST_FILE* File) { + return std::accumulate(File->ChunkParts.get(), File->ChunkParts.get() + File->ChunkPartCount, 0ull, + [](uint64_t sum, const MANIFEST_CHUNK_PART& curr) { + return sum + curr.Size; + }); +} + +char* ManifestFileGetSha1(MANIFEST_FILE* File) { + return File->ShaHash; +} + +MANIFEST_CHUNK* ManifestFileChunkGetChunk(MANIFEST_CHUNK_PART* ChunkPart) { + return ChunkPart->Chunk.get(); +} + +void ManifestFileChunkGetData(MANIFEST_CHUNK_PART* ChunkPart, uint32_t* POffset, uint32_t* PSize) { + *POffset = ChunkPart->Offset; + *PSize = ChunkPart->Size; +} + +char* ManifestChunkGetGuid(MANIFEST_CHUNK* Chunk) { + return Chunk->Guid; +} + +char* ManifestChunkGetSha1(MANIFEST_CHUNK* Chunk) { + return Chunk->ShaHash; +} + +// Example input UrlBuffer: https://epicgames-download1.akamaized.net/Builds/Fortnite/CloudDir/ChunksV3/ +// Make sure UrlBuffer has some extra space in front as well +void ManifestChunkAppendUrl(MANIFEST_CHUNK* Chunk, char* UrlBuffer) { + sprintf(UrlBuffer, "%s%02d/%016llX_%016llX%016llX.chunk", UrlBuffer, Chunk->Group, Chunk->Hash, ntohll(*(uint64_t*)Chunk->Guid), ntohll(*(uint64_t*)(Chunk->Guid + 8))); +} \ No newline at end of file diff --git a/web/manifest.h b/web/manifest.h new file mode 100644 index 0000000..e6ae8cd --- /dev/null +++ b/web/manifest.h @@ -0,0 +1,47 @@ +#pragma once +#include +#include +#include +namespace fs = std::filesystem; + +// Grabbing and parsing manifest + +typedef struct _MANIFEST MANIFEST; +typedef struct _MANIFEST_FILE MANIFEST_FILE; +typedef struct _MANIFEST_CHUNK MANIFEST_CHUNK; +typedef struct _MANIFEST_CHUNK_PART MANIFEST_CHUNK_PART; + +uint64_t ManifestDownloadSize(MANIFEST* Manifest); +uint64_t ManifestInstallSize(MANIFEST* Manifest); +void ManifestGetFiles(MANIFEST* Manifest, MANIFEST_FILE** PFileList, uint32_t* PFileCount, uint16_t* PStrideSize); +void ManifestGetChunks(MANIFEST* Manifest, std::shared_ptr** PChunkList, uint32_t* PChunkCount); +MANIFEST_FILE* ManifestGetFile(MANIFEST* Manifest, const char* Filename); +void ManifestGetCloudDir(MANIFEST* Manifest, char* CloudDirHostBuffer, char* CloudDirPathBuffer); +void ManifestGetLaunchInfo(MANIFEST* Manifest, char* ExeBuffer, char* CommandBuffer); +void ManifestDelete(MANIFEST* Manifest); + +// Authentication + +typedef struct _MANIFEST_AUTH MANIFEST_AUTH; +bool ManifestAuthGrab(MANIFEST_AUTH** PManifestAuth); +void ManifestAuthDelete(MANIFEST_AUTH* ManifestAuth); +bool ManifestAuthGetManifest(MANIFEST_AUTH* ManifestAuth, fs::path CachePath, MANIFEST** PManifest); + +// Files + +void ManifestFileGetName(MANIFEST_FILE* File, char* FilenameBuffer); +void ManifestFileGetChunks(MANIFEST_FILE* File, MANIFEST_CHUNK_PART** PChunkPartList, uint32_t* PChunkPartCount, uint16_t* PStrideSize); +bool ManifestFileGetChunkIndex(MANIFEST_FILE* File, uint64_t Offset, uint32_t* ChunkIndex, uint32_t* ChunkOffset); +uint64_t ManifestFileGetFileSize(MANIFEST_FILE* File); +char* ManifestFileGetSha1(MANIFEST_FILE* File); + +// File Chunks + +MANIFEST_CHUNK* ManifestFileChunkGetChunk(MANIFEST_CHUNK_PART* ChunkPart); +void ManifestFileChunkGetData(MANIFEST_CHUNK_PART* ChunkPart, uint32_t* POffset, uint32_t* PSize); + +// Chunks + +char* ManifestChunkGetGuid(MANIFEST_CHUNK* Chunk); +char* ManifestChunkGetSha1(MANIFEST_CHUNK* Chunk); +void ManifestChunkAppendUrl(MANIFEST_CHUNK* Chunk, char* UrlBuffer); diff --git a/web/url.hh b/web/url.hh new file mode 100644 index 0000000..ad87a6b --- /dev/null +++ b/web/url.hh @@ -0,0 +1,70 @@ +#include +#include // find + +struct Uri +{ +public: + std::string QueryString, Path, Protocol, Host, Port; + + static Uri Parse(const std::string& uri) + { + Uri result; + + typedef std::string::const_iterator iterator_t; + + if (uri.length() == 0) + return result; + + iterator_t uriEnd = uri.end(); + + // get query start + iterator_t queryStart = std::find(uri.begin(), uriEnd, '?'); + + // protocol + iterator_t protocolStart = uri.begin(); + iterator_t protocolEnd = std::find(protocolStart, uriEnd, ':'); //"://"); + + if (protocolEnd != uriEnd) + { + std::string prot = &*(protocolEnd); + if ((prot.length() > 3) && (prot.substr(0, 3) == "://")) + { + result.Protocol = std::string(protocolStart, protocolEnd); + protocolEnd += 3; // :// + } + else + protocolEnd = uri.begin(); // no protocol + } + else + protocolEnd = uri.begin(); // no protocol + + // host + iterator_t hostStart = protocolEnd; + iterator_t pathStart = std::find(hostStart, uriEnd, '/'); // get pathStart + + iterator_t hostEnd = std::find(protocolEnd, + (pathStart != uriEnd) ? pathStart : queryStart, + ':'); // check for port + + result.Host = std::string(hostStart, hostEnd); + + // port + if ((hostEnd != uriEnd) && ((&*(hostEnd))[0] == ':')) // we have a port + { + hostEnd++; + iterator_t portEnd = (pathStart != uriEnd) ? pathStart : queryStart; + result.Port = std::string(hostEnd, portEnd); + } + + // path + if (pathStart != uriEnd) + result.Path = std::string(pathStart, queryStart); + + // query + if (queryStart != uriEnd) + result.QueryString = std::string(queryStart, uri.end()); + + return result; + + } // Parse +}; // uri \ No newline at end of file diff --git a/winfspcheck.cpp b/winfspcheck.cpp new file mode 100644 index 0000000..c3c3ad8 --- /dev/null +++ b/winfspcheck.cpp @@ -0,0 +1,60 @@ +#include "winfspcheck.h" + +#define WIN32_LEAN_AND_MEAN +#include +#include + +#include +namespace fs = std::filesystem; + + +bool alreadyLoaded = false; + +WinFspCheckResult LoadWinFsp() { + if (alreadyLoaded) { + return WinFspCheckResult::LOADED; + } + + DWORD driverByteCount; + if (EnumDeviceDrivers(NULL, 0, &driverByteCount)) { + + DWORD driverCount = driverByteCount / sizeof(LPVOID); + LPVOID* drivers = new LPVOID[driverCount]; + WinFspCheckResult result = WinFspCheckResult::NOT_FOUND; + + if (EnumDeviceDrivers(drivers, driverByteCount, &driverByteCount)) { + char driverFilename[MAX_PATH]; + + for (int i = 0; i < driverCount; ++i) { + if (GetDeviceDriverFileNameA(drivers[i], driverFilename, MAX_PATH) && strstr(driverFilename, "winfsp")) + { + + fs::path dll; + if (strncmp(driverFilename, "\\??\\", 4) == 0) { + dll = driverFilename + 4; + } + else { + dll = driverFilename; + } + dll = dll.replace_extension(".dll"); + + if (fs::status(dll).type() != fs::file_type::regular) { + result = WinFspCheckResult::NO_DLL; + continue; + } + if (LoadLibraryA(dll.string().c_str()) == NULL) { + result = WinFspCheckResult::CANNOT_LOAD; + continue; + } + alreadyLoaded = true; + result = WinFspCheckResult::LOADED; + break; + } + } + } + + delete[] drivers; + return result; + } + return WinFspCheckResult::CANNOT_ENUMERATE; +} \ No newline at end of file diff --git a/winfspcheck.h b/winfspcheck.h new file mode 100644 index 0000000..9e82015 --- /dev/null +++ b/winfspcheck.h @@ -0,0 +1,11 @@ +#pragma once + +enum class WinFspCheckResult { + LOADED = 0, // successfully loaded + CANNOT_ENUMERATE = 1, // cannot enumerate over drivers + NOT_FOUND = 2, // cannot find driver + NO_DLL = 3, // no dll in driver folder + CANNOT_LOAD = 4, // cannot load dll +}; + +WinFspCheckResult LoadWinFsp(); \ No newline at end of file