Skip to content

Commit

Permalink
Support prepared statements and parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
vhiairrassary committed Dec 16, 2024
1 parent bd381c1 commit 43152cb
Showing 1 changed file with 154 additions and 46 deletions.
200 changes: 154 additions & 46 deletions src/httpserver_extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,13 @@ struct HttpServerState {

static HttpServerState global_state;

struct HttpServerException: public std::exception {
int status;
std::string message;

HttpServerException(int status, const std::string& message) : message(message), status(status) {}
};

std::string GetColumnType(MaterializedQueryResult &result, idx_t column) {
if (result.RowCount() == 0) {
return "String";
Expand Down Expand Up @@ -152,28 +159,28 @@ std::string base64_decode(const std::string &in) {
return out;
}

// Auth Check
bool IsAuthenticated(const duckdb_httplib_openssl::Request& req) {
// Check authentication
void CheckAuthentication(const duckdb_httplib_openssl::Request& req) {
if (global_state.auth_token.empty()) {
return true; // No authentication required if no token is set
return; // No authentication required if no token is set
}

// Check for X-API-Key header
auto api_key = req.get_header_value("X-API-Key");
if (!api_key.empty() && api_key == global_state.auth_token) {
return true;
return;
}

// Check for Basic Auth
auto auth = req.get_header_value("Authorization");
if (!auth.empty() && auth.compare(0, 6, "Basic ") == 0) {
std::string decoded_auth = base64_decode(auth.substr(6));
if (decoded_auth == global_state.auth_token) {
return true;
return;
}
}

return false;
throw HttpServerException(401, "Unauthorized");
}

// Convert the query result to NDJSON (JSONEachRow) format
Expand Down Expand Up @@ -217,49 +224,130 @@ static std::string ConvertResultToNDJSON(MaterializedQueryResult &result) {
return ndjson_output;
}

// Handle both GET and POST requests
void HandleHttpRequest(const duckdb_httplib_openssl::Request& req, duckdb_httplib_openssl::Response& res) {
std::string query;
BoundParameterData ExtractQueryParameter(const std::string& key, yyjson_val* parameterVal) {
if (!yyjson_is_obj(parameterVal)) {
throw HttpServerException(400, "The parameter `" + key + "` parameter must be an object");
}

// Check authentication
if (!IsAuthenticated(req)) {
res.status = 401;
res.set_content("Unauthorized", "text/plain");
return;
auto typeVal = yyjson_obj_get(parameterVal, "type");
if (!typeVal) {
throw HttpServerException(400, "The parameter `" + key + "` does not have a `type` field");
}
if (!yyjson_is_str(typeVal)) {
throw HttpServerException(400, "The field `type` for the parameter `" + key + "` must be a string");
}
auto type = std::string(yyjson_get_str(typeVal));

// CORS allow
res.set_header("Access-Control-Allow-Origin", "*");
res.set_header("Access-Control-Allow-Methods", "GET, POST, OPTIONS, PUT");
res.set_header("Access-Control-Allow-Headers", "*");
res.set_header("Access-Control-Allow-Credentials", "true");
res.set_header("Access-Control-Max-Age", "86400");
auto valueVal = yyjson_obj_get(parameterVal, "value");
if (!valueVal) {
throw HttpServerException(400, "The parameter `" + key + "` does not have a `value` field");
}

// Handle preflight OPTIONS request
if (req.method == "OPTIONS") {
res.status = 204; // No content
return;
if (type == "TEXT") {
if (!yyjson_is_str(valueVal)) {
throw HttpServerException(400, "The field `value` for the parameter `" + key + "` must be a string");
}

return BoundParameterData(Value(yyjson_get_str(valueVal)));
}
else if (type == "BOOLEAN") {
if (!yyjson_is_bool(valueVal)) {
throw HttpServerException(400, "The field `value` for the parameter `" + key + "` must be a boolean");
}

return BoundParameterData(Value(bool(yyjson_get_bool(valueVal))));
}

throw HttpServerException(400, "Unsupported type " + type + " the parameter `" + key + "`");
}

case_insensitive_map_t<BoundParameterData> ExtractQueryParameters(yyjson_doc* parametersDoc) {
if (!parametersDoc) {
throw HttpServerException(400, "Unable to parse the `parameters` parameter");
}

auto parametersRoot = yyjson_doc_get_root(parametersDoc);
if (!yyjson_is_obj(parametersRoot)) {
throw HttpServerException(400, "The `parameters` parameter must be an object");
}

case_insensitive_map_t<BoundParameterData> named_values;

size_t idx, max;
yyjson_val *parameterKeyVal, *parameterVal;
yyjson_obj_foreach(parametersRoot, idx, max, parameterKeyVal, parameterVal) {
auto parameterKeyString = std::string(yyjson_get_str(parameterKeyVal));

named_values[parameterKeyString] = ExtractQueryParameter(parameterKeyString, parameterVal);
}

return named_values;
}

case_insensitive_map_t<BoundParameterData> ExtractQueryParametersWrapper(const duckdb_httplib_openssl::Request& req) {
yyjson_doc *parametersDoc = nullptr;

try {
auto parametersJson = req.get_param_value("parameters");
auto parametersJsonCStr = parametersJson.c_str();
parametersDoc = yyjson_read(parametersJsonCStr, strlen(parametersJsonCStr), 0);
return ExtractQueryParameters(parametersDoc);
}
catch (const Exception& exception) {
yyjson_doc_free(parametersDoc);

throw exception;
}
}

// Execute query (optionally using a prepared statement)
std::unique_ptr<MaterializedQueryResult> ExecuteQuery(
const duckdb_httplib_openssl::Request& req,
const std::string& query
) {
Connection con(*global_state.db_instance);
std::unique_ptr<MaterializedQueryResult> result;

if (req.has_param("parameters")) {
auto prepared_stmt = con.Prepare(query);
if (prepared_stmt->HasError()) {
throw HttpServerException(500, prepared_stmt->GetError());
}

auto named_values = ExtractQueryParametersWrapper(req);

auto prepared_stmt_result = prepared_stmt->Execute(named_values);
D_ASSERT(prepared_stmt_result->type == QueryResultType::STREAM_RESULT);
result = unique_ptr_cast<QueryResult, StreamQueryResult>(std::move(prepared_stmt_result))->Materialize();
} else {
result = con.Query(query);
}

if (result->HasError()) {
throw HttpServerException(500, result->GetError());
}

return result;
}

std::string ExtractQuery(const duckdb_httplib_openssl::Request& req) {
// Check if the query is in the URL parameters
if (req.has_param("query")) {
query = req.get_param_value("query");
return req.get_param_value("query");
}
else if (req.has_param("q")) {
query = req.get_param_value("q");
return req.get_param_value("q");
}

// If not in URL, and it's a POST request, check the body
else if (req.method == "POST" && !req.body.empty()) {
query = req.body;
}
// If no query found, return an error
else {
res.status = 200;
res.set_content(reinterpret_cast<char const*>(playgroundContent), "text/html");
return;
return req.body;
}

// Set default format to JSONCompact
throw HttpServerException(200, reinterpret_cast<char const*>(playgroundContent));
}

std::string ExtractFormat(const duckdb_httplib_openssl::Request& req) {
std::string format = "JSONEachRow";

// Check for format in URL parameter or header
Expand All @@ -271,24 +359,39 @@ void HandleHttpRequest(const duckdb_httplib_openssl::Request& req, duckdb_httpli
format = req.get_header_value("format");
}

return format;
}

// Handle both GET and POST requests
void HandleHttpRequest(const duckdb_httplib_openssl::Request& req, duckdb_httplib_openssl::Response& res) {
CheckAuthentication(req);

// CORS allow
res.set_header("Access-Control-Allow-Origin", "*");
res.set_header("Access-Control-Allow-Methods", "GET, POST, OPTIONS, PUT");
res.set_header("Access-Control-Allow-Headers", "*");
res.set_header("Access-Control-Allow-Credentials", "true");
res.set_header("Access-Control-Max-Age", "86400");

// Handle preflight OPTIONS request
if (req.method == "OPTIONS") {
res.status = 204; // No content
return;
}

auto query = ExtractQuery(req);
auto format = ExtractFormat(req);

try {
if (!global_state.db_instance) {
throw IOException("Database instance not initialized");
}

Connection con(*global_state.db_instance);
auto start = std::chrono::system_clock::now();
auto result = con.Query(query);
auto result = ExecuteQuery(req, query);
auto end = std::chrono::system_clock::now();
auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);

if (result->HasError()) {
res.status = 500;
res.set_content(result->GetError(), "text/plain");
return;
}


ReqStats stats{
static_cast<float>(elapsed.count()) / 1000,
0,
Expand All @@ -308,7 +411,12 @@ void HandleHttpRequest(const duckdb_httplib_openssl::Request& req, duckdb_httpli
res.set_content(json_output, "application/x-ndjson");
}

} catch (const Exception& ex) {
}
catch (const HttpServerException& ex) {
res.status = ex.status;
res.set_content(ex.message, "text/plain");
}
catch (const Exception& ex) {
res.status = 500;
std::string error_message = "Code: 59, e.displayText() = DB::Exception: " + std::string(ex.what());
res.set_content(error_message, "text/plain");
Expand All @@ -325,9 +433,9 @@ void HttpServerStart(DatabaseInstance& db, string_t host, int32_t port, string_t
global_state.is_running = true;
global_state.auth_token = auth.GetString();

// Custom basepath, defaults to root /
// Custom basepath, defaults to root /
const char* base_path_env = std::getenv("DUCKDB_HTTPSERVER_BASEPATH");
std::string base_path = "/";
std::string base_path = "/";

if (base_path_env && base_path_env[0] == '/' && strlen(base_path_env) > 1) {
base_path = std::string(base_path_env);
Expand Down

0 comments on commit 43152cb

Please sign in to comment.