Skip to content

Commit

Permalink
Support prepared statements and parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
vhiairrassary committed Dec 17, 2024
1 parent bd381c1 commit f803e12
Show file tree
Hide file tree
Showing 18 changed files with 823 additions and 326 deletions.
6 changes: 5 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,16 @@ include_directories(
# Embed ./src/assets/index.html as a C++ header
add_custom_command(
OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/playground.hpp
COMMAND ${CMAKE_COMMAND} -P ${PROJECT_SOURCE_DIR}/embed.cmake ${PROJECT_SOURCE_DIR}/src/assets/index.html ${CMAKE_CURRENT_BINARY_DIR}/playground.hpp playgroundContent
COMMAND ${CMAKE_COMMAND} -P ${PROJECT_SOURCE_DIR}/embed.cmake ${PROJECT_SOURCE_DIR}/src/assets/index.html ${CMAKE_CURRENT_BINARY_DIR}/httpserver_extension/http_handler/playground.hpp playgroundContent
DEPENDS ${PROJECT_SOURCE_DIR}/src/assets/index.html
)

set(EXTENSION_SOURCES
src/httpserver_extension.cpp
src/http_handler/authentication.cpp
src/http_handler/bindings.cpp
src/http_handler/handler.cpp
src/http_handler/response_serializer.cpp
${CMAKE_CURRENT_BINARY_DIR}/playground.hpp
)

Expand Down
55 changes: 55 additions & 0 deletions src/http_handler/authentication.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
#include "httpserver_extension/http_handler/common.hpp"
#include "httpserver_extension/state.hpp"
#include <string>
#include <vector>

#define CPPHTTPLIB_OPENSSL_SUPPORT
#include "httplib.hpp"

namespace duckdb_httpserver {

// Base64 decoding function
static std::string base64_decode(const std::string &in) {
std::string out;
std::vector<int> T(256, -1);
for (int i = 0; i < 64; i++)
T["ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"[i]] = i;

int val = 0, valb = -8;
for (unsigned char c : in) {
if (T[c] == -1) break;
val = (val << 6) + T[c];
valb += 6;
if (valb >= 0) {
out.push_back(char((val >> valb) & 0xFF));
valb -= 8;
}
}
return out;
}

// Check authentication
void CheckAuthentication(const duckdb_httplib_openssl::Request& req) {
if (global_state.auth_token.empty()) {
return; // No authentication required if no token is set
}

// Check for X-API-Key header
auto api_key = req.get_header_value("X-API-Key");
if (!api_key.empty() && api_key == global_state.auth_token) {
return;
}

// Check for Basic Auth
auto auth = req.get_header_value("Authorization");
if (!auth.empty() && auth.compare(0, 6, "Basic ") == 0) {
std::string decoded_auth = base64_decode(auth.substr(6));
if (decoded_auth == global_state.auth_token) {
return;
}
}

throw HttpHandlerException(401, "Unauthorized");
}

} // namespace duckdb_httpserver
91 changes: 91 additions & 0 deletions src/http_handler/bindings.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
#include "httpserver_extension/http_handler/common.hpp"
#include "duckdb.hpp"

#include "yyjson.hpp"
#include <string>

#define CPPHTTPLIB_OPENSSL_SUPPORT
#include "httplib.hpp"

using namespace duckdb;
using namespace duckdb_yyjson;

namespace duckdb_httpserver {

static BoundParameterData ExtractQueryParameter(const std::string& key, yyjson_val* parameterVal) {
if (!yyjson_is_obj(parameterVal)) {
throw HttpHandlerException(400, "The parameter `" + key + "` parameter must be an object");
}

auto typeVal = yyjson_obj_get(parameterVal, "type");
if (!typeVal) {
throw HttpHandlerException(400, "The parameter `" + key + "` does not have a `type` field");
}
if (!yyjson_is_str(typeVal)) {
throw HttpHandlerException(400, "The field `type` for the parameter `" + key + "` must be a string");
}
auto type = std::string(yyjson_get_str(typeVal));

auto valueVal = yyjson_obj_get(parameterVal, "value");
if (!valueVal) {
throw HttpHandlerException(400, "The parameter `" + key + "` does not have a `value` field");
}

if (type == "TEXT") {
if (!yyjson_is_str(valueVal)) {
throw HttpHandlerException(400, "The field `value` for the parameter `" + key + "` must be a string");
}

return BoundParameterData(Value(yyjson_get_str(valueVal)));
}
else if (type == "BOOLEAN") {
if (!yyjson_is_bool(valueVal)) {
throw HttpHandlerException(400, "The field `value` for the parameter `" + key + "` must be a boolean");
}

return BoundParameterData(Value(bool(yyjson_get_bool(valueVal))));
}

throw HttpHandlerException(400, "Unsupported type " + type + " the parameter `" + key + "`");
}

static case_insensitive_map_t<BoundParameterData> ExtractQueryParametersImpl(yyjson_doc* parametersDoc) {
if (!parametersDoc) {
throw HttpHandlerException(400, "Unable to parse the `parameters` parameter");
}

auto parametersRoot = yyjson_doc_get_root(parametersDoc);
if (!yyjson_is_obj(parametersRoot)) {
throw HttpHandlerException(400, "The `parameters` parameter must be an object");
}

case_insensitive_map_t<BoundParameterData> named_values;

size_t idx, max;
yyjson_val *parameterKeyVal, *parameterVal;
yyjson_obj_foreach(parametersRoot, idx, max, parameterKeyVal, parameterVal) {
auto parameterKeyString = std::string(yyjson_get_str(parameterKeyVal));

named_values[parameterKeyString] = ExtractQueryParameter(parameterKeyString, parameterVal);
}

return named_values;
}

case_insensitive_map_t<BoundParameterData> ExtractQueryParameters(const duckdb_httplib_openssl::Request& req) {
yyjson_doc *parametersDoc = nullptr;

try {
auto parametersJson = req.get_param_value("parameters");
auto parametersJsonCStr = parametersJson.c_str();
parametersDoc = yyjson_read(parametersJsonCStr, strlen(parametersJsonCStr), 0);
return ExtractQueryParametersImpl(parametersDoc);
}
catch (const Exception& exception) {
yyjson_doc_free(parametersDoc);

throw exception;
}
}

} // namespace duckdb_httpserver
150 changes: 150 additions & 0 deletions src/http_handler/handler.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
#include "httpserver_extension/http_handler/authentication.hpp"
#include "httpserver_extension/http_handler/bindings.hpp"
#include "httpserver_extension/http_handler/common.hpp"
#include "httpserver_extension/http_handler/handler.hpp"
#include "httpserver_extension/http_handler/playground.hpp"
#include "httpserver_extension/http_handler/response_serializer.hpp"
#include "httpserver_extension/state.hpp"
#include "duckdb.hpp"

#include <string>
#include <vector>

#define CPPHTTPLIB_OPENSSL_SUPPORT
#include "httplib.hpp"

using namespace duckdb;

namespace duckdb_httpserver {

// Handle both GET and POST requests
void HttpHandler(const duckdb_httplib_openssl::Request& req, duckdb_httplib_openssl::Response& res) {
try {
CheckAuthentication(req);

// CORS allow
res.set_header("Access-Control-Allow-Origin", "*");
res.set_header("Access-Control-Allow-Methods", "GET, POST, OPTIONS, PUT");
res.set_header("Access-Control-Allow-Headers", "*");
res.set_header("Access-Control-Allow-Credentials", "true");
res.set_header("Access-Control-Max-Age", "86400");

// Handle preflight OPTIONS request
if (req.method == "OPTIONS") {
res.status = 204; // No content
return;
}

auto query = ExtractQuery(req);
auto format = ExtractFormat(req);

if (query == "") {
res.status = 200;
res.set_content(reinterpret_cast<char const*>(playgroundContent), sizeof(playgroundContent), "text/html");
return;
}

if (!global_state.db_instance) {
throw IOException("Database instance not initialized");
}

auto start = std::chrono::system_clock::now();
auto result = ExecuteQuery(req, query);
auto end = std::chrono::system_clock::now();
auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);

QueryExecStats stats{
static_cast<float>(elapsed.count()) / 1000,
0,
0
};

// Format Options
if (format == "JSONEachRow") {
std::string json_output = ConvertResultToNDJSON(*result);
res.set_content(json_output, "application/x-ndjson");
} else if (format == "JSONCompact") {
std::string json_output = ConvertResultToJSON(*result, stats);
res.set_content(json_output, "application/json");
} else {
// Default to NDJSON for DuckDB's own queries
std::string json_output = ConvertResultToNDJSON(*result);
res.set_content(json_output, "application/x-ndjson");
}

}
catch (const HttpHandlerException& ex) {
res.status = ex.status;
res.set_content(ex.message, "text/plain");
}
catch (const Exception& ex) {
res.status = 500;
std::string error_message = "Code: 59, e.displayText() = DB::Exception: " + std::string(ex.what());
res.set_content(error_message, "text/plain");
}
}

// Execute query (optionally using a prepared statement)
std::unique_ptr<MaterializedQueryResult> ExecuteQuery(
const duckdb_httplib_openssl::Request& req,
const std::string& query
) {
Connection con(*global_state.db_instance);
std::unique_ptr<MaterializedQueryResult> result;

if (req.has_param("parameters")) {
auto prepared_stmt = con.Prepare(query);
if (prepared_stmt->HasError()) {
throw HttpHandlerException(500, prepared_stmt->GetError());
}

auto named_values = ExtractQueryParameters(req);

auto prepared_stmt_result = prepared_stmt->Execute(named_values);
D_ASSERT(prepared_stmt_result->type == QueryResultType::STREAM_RESULT);
result = unique_ptr_cast<QueryResult, StreamQueryResult>(std::move(prepared_stmt_result))->Materialize();
} else {
result = con.Query(query);
}

if (result->HasError()) {
throw HttpHandlerException(500, result->GetError());
}

return result;
}

std::string ExtractQuery(const duckdb_httplib_openssl::Request& req) {
// Check if the query is in the URL parameters
if (req.has_param("query")) {
return req.get_param_value("query");
}
else if (req.has_param("q")) {
return req.get_param_value("q");
}

// If not in URL, and it's a POST request, check the body
else if (req.method == "POST" && !req.body.empty()) {
return req.body;
}

// std::optional is not available for this project
return "";
}

std::string ExtractFormat(const duckdb_httplib_openssl::Request& req) {
std::string format = "JSONEachRow";

// Check for format in URL parameter or header
if (req.has_param("default_format")) {
format = req.get_param_value("default_format");
} else if (req.has_header("X-ClickHouse-Format")) {
format = req.get_header_value("X-ClickHouse-Format");
} else if (req.has_header("format")) {
format = req.get_header_value("format");
}

return format;
}

} // namespace duckdb_httpserver
Loading

0 comments on commit f803e12

Please sign in to comment.