generated from duckdb/extension-template
-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support prepared statements and parameters
- Loading branch information
1 parent
bd381c1
commit f803e12
Showing
18 changed files
with
823 additions
and
326 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
#include "httpserver_extension/http_handler/common.hpp" | ||
#include "httpserver_extension/state.hpp" | ||
#include <string> | ||
#include <vector> | ||
|
||
#define CPPHTTPLIB_OPENSSL_SUPPORT | ||
#include "httplib.hpp" | ||
|
||
namespace duckdb_httpserver { | ||
|
||
// Base64 decoding function | ||
static std::string base64_decode(const std::string &in) { | ||
std::string out; | ||
std::vector<int> T(256, -1); | ||
for (int i = 0; i < 64; i++) | ||
T["ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"[i]] = i; | ||
|
||
int val = 0, valb = -8; | ||
for (unsigned char c : in) { | ||
if (T[c] == -1) break; | ||
val = (val << 6) + T[c]; | ||
valb += 6; | ||
if (valb >= 0) { | ||
out.push_back(char((val >> valb) & 0xFF)); | ||
valb -= 8; | ||
} | ||
} | ||
return out; | ||
} | ||
|
||
// Check authentication | ||
void CheckAuthentication(const duckdb_httplib_openssl::Request& req) { | ||
if (global_state.auth_token.empty()) { | ||
return; // No authentication required if no token is set | ||
} | ||
|
||
// Check for X-API-Key header | ||
auto api_key = req.get_header_value("X-API-Key"); | ||
if (!api_key.empty() && api_key == global_state.auth_token) { | ||
return; | ||
} | ||
|
||
// Check for Basic Auth | ||
auto auth = req.get_header_value("Authorization"); | ||
if (!auth.empty() && auth.compare(0, 6, "Basic ") == 0) { | ||
std::string decoded_auth = base64_decode(auth.substr(6)); | ||
if (decoded_auth == global_state.auth_token) { | ||
return; | ||
} | ||
} | ||
|
||
throw HttpHandlerException(401, "Unauthorized"); | ||
} | ||
|
||
} // namespace duckdb_httpserver |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
#include "httpserver_extension/http_handler/common.hpp" | ||
#include "duckdb.hpp" | ||
|
||
#include "yyjson.hpp" | ||
#include <string> | ||
|
||
#define CPPHTTPLIB_OPENSSL_SUPPORT | ||
#include "httplib.hpp" | ||
|
||
using namespace duckdb; | ||
using namespace duckdb_yyjson; | ||
|
||
namespace duckdb_httpserver { | ||
|
||
static BoundParameterData ExtractQueryParameter(const std::string& key, yyjson_val* parameterVal) { | ||
if (!yyjson_is_obj(parameterVal)) { | ||
throw HttpHandlerException(400, "The parameter `" + key + "` parameter must be an object"); | ||
} | ||
|
||
auto typeVal = yyjson_obj_get(parameterVal, "type"); | ||
if (!typeVal) { | ||
throw HttpHandlerException(400, "The parameter `" + key + "` does not have a `type` field"); | ||
} | ||
if (!yyjson_is_str(typeVal)) { | ||
throw HttpHandlerException(400, "The field `type` for the parameter `" + key + "` must be a string"); | ||
} | ||
auto type = std::string(yyjson_get_str(typeVal)); | ||
|
||
auto valueVal = yyjson_obj_get(parameterVal, "value"); | ||
if (!valueVal) { | ||
throw HttpHandlerException(400, "The parameter `" + key + "` does not have a `value` field"); | ||
} | ||
|
||
if (type == "TEXT") { | ||
if (!yyjson_is_str(valueVal)) { | ||
throw HttpHandlerException(400, "The field `value` for the parameter `" + key + "` must be a string"); | ||
} | ||
|
||
return BoundParameterData(Value(yyjson_get_str(valueVal))); | ||
} | ||
else if (type == "BOOLEAN") { | ||
if (!yyjson_is_bool(valueVal)) { | ||
throw HttpHandlerException(400, "The field `value` for the parameter `" + key + "` must be a boolean"); | ||
} | ||
|
||
return BoundParameterData(Value(bool(yyjson_get_bool(valueVal)))); | ||
} | ||
|
||
throw HttpHandlerException(400, "Unsupported type " + type + " the parameter `" + key + "`"); | ||
} | ||
|
||
static case_insensitive_map_t<BoundParameterData> ExtractQueryParametersImpl(yyjson_doc* parametersDoc) { | ||
if (!parametersDoc) { | ||
throw HttpHandlerException(400, "Unable to parse the `parameters` parameter"); | ||
} | ||
|
||
auto parametersRoot = yyjson_doc_get_root(parametersDoc); | ||
if (!yyjson_is_obj(parametersRoot)) { | ||
throw HttpHandlerException(400, "The `parameters` parameter must be an object"); | ||
} | ||
|
||
case_insensitive_map_t<BoundParameterData> named_values; | ||
|
||
size_t idx, max; | ||
yyjson_val *parameterKeyVal, *parameterVal; | ||
yyjson_obj_foreach(parametersRoot, idx, max, parameterKeyVal, parameterVal) { | ||
auto parameterKeyString = std::string(yyjson_get_str(parameterKeyVal)); | ||
|
||
named_values[parameterKeyString] = ExtractQueryParameter(parameterKeyString, parameterVal); | ||
} | ||
|
||
return named_values; | ||
} | ||
|
||
case_insensitive_map_t<BoundParameterData> ExtractQueryParameters(const duckdb_httplib_openssl::Request& req) { | ||
yyjson_doc *parametersDoc = nullptr; | ||
|
||
try { | ||
auto parametersJson = req.get_param_value("parameters"); | ||
auto parametersJsonCStr = parametersJson.c_str(); | ||
parametersDoc = yyjson_read(parametersJsonCStr, strlen(parametersJsonCStr), 0); | ||
return ExtractQueryParametersImpl(parametersDoc); | ||
} | ||
catch (const Exception& exception) { | ||
yyjson_doc_free(parametersDoc); | ||
|
||
throw exception; | ||
} | ||
} | ||
|
||
} // namespace duckdb_httpserver |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
#include "httpserver_extension/http_handler/authentication.hpp" | ||
#include "httpserver_extension/http_handler/bindings.hpp" | ||
#include "httpserver_extension/http_handler/common.hpp" | ||
#include "httpserver_extension/http_handler/handler.hpp" | ||
#include "httpserver_extension/http_handler/playground.hpp" | ||
#include "httpserver_extension/http_handler/response_serializer.hpp" | ||
#include "httpserver_extension/state.hpp" | ||
#include "duckdb.hpp" | ||
|
||
#include <string> | ||
#include <vector> | ||
|
||
#define CPPHTTPLIB_OPENSSL_SUPPORT | ||
#include "httplib.hpp" | ||
|
||
using namespace duckdb; | ||
|
||
namespace duckdb_httpserver { | ||
|
||
// Handle both GET and POST requests | ||
void HttpHandler(const duckdb_httplib_openssl::Request& req, duckdb_httplib_openssl::Response& res) { | ||
try { | ||
CheckAuthentication(req); | ||
|
||
// CORS allow | ||
res.set_header("Access-Control-Allow-Origin", "*"); | ||
res.set_header("Access-Control-Allow-Methods", "GET, POST, OPTIONS, PUT"); | ||
res.set_header("Access-Control-Allow-Headers", "*"); | ||
res.set_header("Access-Control-Allow-Credentials", "true"); | ||
res.set_header("Access-Control-Max-Age", "86400"); | ||
|
||
// Handle preflight OPTIONS request | ||
if (req.method == "OPTIONS") { | ||
res.status = 204; // No content | ||
return; | ||
} | ||
|
||
auto query = ExtractQuery(req); | ||
auto format = ExtractFormat(req); | ||
|
||
if (query == "") { | ||
res.status = 200; | ||
res.set_content(reinterpret_cast<char const*>(playgroundContent), sizeof(playgroundContent), "text/html"); | ||
return; | ||
} | ||
|
||
if (!global_state.db_instance) { | ||
throw IOException("Database instance not initialized"); | ||
} | ||
|
||
auto start = std::chrono::system_clock::now(); | ||
auto result = ExecuteQuery(req, query); | ||
auto end = std::chrono::system_clock::now(); | ||
auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(end - start); | ||
|
||
QueryExecStats stats{ | ||
static_cast<float>(elapsed.count()) / 1000, | ||
0, | ||
0 | ||
}; | ||
|
||
// Format Options | ||
if (format == "JSONEachRow") { | ||
std::string json_output = ConvertResultToNDJSON(*result); | ||
res.set_content(json_output, "application/x-ndjson"); | ||
} else if (format == "JSONCompact") { | ||
std::string json_output = ConvertResultToJSON(*result, stats); | ||
res.set_content(json_output, "application/json"); | ||
} else { | ||
// Default to NDJSON for DuckDB's own queries | ||
std::string json_output = ConvertResultToNDJSON(*result); | ||
res.set_content(json_output, "application/x-ndjson"); | ||
} | ||
|
||
} | ||
catch (const HttpHandlerException& ex) { | ||
res.status = ex.status; | ||
res.set_content(ex.message, "text/plain"); | ||
} | ||
catch (const Exception& ex) { | ||
res.status = 500; | ||
std::string error_message = "Code: 59, e.displayText() = DB::Exception: " + std::string(ex.what()); | ||
res.set_content(error_message, "text/plain"); | ||
} | ||
} | ||
|
||
// Execute query (optionally using a prepared statement) | ||
std::unique_ptr<MaterializedQueryResult> ExecuteQuery( | ||
const duckdb_httplib_openssl::Request& req, | ||
const std::string& query | ||
) { | ||
Connection con(*global_state.db_instance); | ||
std::unique_ptr<MaterializedQueryResult> result; | ||
|
||
if (req.has_param("parameters")) { | ||
auto prepared_stmt = con.Prepare(query); | ||
if (prepared_stmt->HasError()) { | ||
throw HttpHandlerException(500, prepared_stmt->GetError()); | ||
} | ||
|
||
auto named_values = ExtractQueryParameters(req); | ||
|
||
auto prepared_stmt_result = prepared_stmt->Execute(named_values); | ||
D_ASSERT(prepared_stmt_result->type == QueryResultType::STREAM_RESULT); | ||
result = unique_ptr_cast<QueryResult, StreamQueryResult>(std::move(prepared_stmt_result))->Materialize(); | ||
} else { | ||
result = con.Query(query); | ||
} | ||
|
||
if (result->HasError()) { | ||
throw HttpHandlerException(500, result->GetError()); | ||
} | ||
|
||
return result; | ||
} | ||
|
||
std::string ExtractQuery(const duckdb_httplib_openssl::Request& req) { | ||
// Check if the query is in the URL parameters | ||
if (req.has_param("query")) { | ||
return req.get_param_value("query"); | ||
} | ||
else if (req.has_param("q")) { | ||
return req.get_param_value("q"); | ||
} | ||
|
||
// If not in URL, and it's a POST request, check the body | ||
else if (req.method == "POST" && !req.body.empty()) { | ||
return req.body; | ||
} | ||
|
||
// std::optional is not available for this project | ||
return ""; | ||
} | ||
|
||
std::string ExtractFormat(const duckdb_httplib_openssl::Request& req) { | ||
std::string format = "JSONEachRow"; | ||
|
||
// Check for format in URL parameter or header | ||
if (req.has_param("default_format")) { | ||
format = req.get_param_value("default_format"); | ||
} else if (req.has_header("X-ClickHouse-Format")) { | ||
format = req.get_header_value("X-ClickHouse-Format"); | ||
} else if (req.has_header("format")) { | ||
format = req.get_header_value("format"); | ||
} | ||
|
||
return format; | ||
} | ||
|
||
} // namespace duckdb_httpserver |
Oops, something went wrong.