Skip to content

Commit

Permalink
skip query result from rewritten queries (#4633)
Browse files Browse the repository at this point in the history
  • Loading branch information
ray6080 authored Dec 30, 2024
1 parent f22023e commit 606eef8
Show file tree
Hide file tree
Showing 7 changed files with 40 additions and 27 deletions.
2 changes: 1 addition & 1 deletion extension/fts/test/test_files/fts_small.test
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
-STATEMENT load extension "${KUZU_ROOT_DIRECTORY}/extension/fts/build/libfts.kuzu_extension"
---- ok
-STATEMENT CALL CREATE_FTS_INDEX('doc', 'docIdx', ['content', 'author', 'name'])
---- ok
---- 0

-LOG QueryFTSConjunctiveSingleKeyword
-STATEMENT CALL QUERY_FTS_INDEX('doc', 'docIdx', 'alice', conjunctive := true) RETURN _node.ID, score
Expand Down
2 changes: 1 addition & 1 deletion extension/httpfs/src/http_config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ HTTPConfig::HTTPConfig(main::ClientContext* context) {

void HTTPConfigEnvProvider::setOptionValue(main::ClientContext* context) {
const auto cacheFileOptionStrVal =
context->getEnvVariable(HTTPCacheFileConfig::HTTP_CACHE_FILE_ENV_VAR);
main::ClientContext::getEnvVariable(HTTPCacheFileConfig::HTTP_CACHE_FILE_ENV_VAR);
if (cacheFileOptionStrVal != "") {
bool enableCacheFile = false;
function::CastString::operation(
Expand Down
10 changes: 5 additions & 5 deletions extension/httpfs/src/s3_download_options.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@ void S3DownloadOptions::registerExtensionOptions(main::Database* db) {
}

void S3DownloadOptions::setEnvValue(main::ClientContext* context) {
auto accessKeyID = context->getEnvVariable(S3AccessKeyID::NAME);
auto secretAccessKey = context->getEnvVariable(S3SecretAccessKey::NAME);
auto endpoint = context->getEnvVariable(S3EndPoint::NAME);
auto urlStyle = context->getEnvVariable(S3URLStyle::NAME);
auto region = context->getEnvVariable(S3Region::NAME);
auto accessKeyID = main::ClientContext::getEnvVariable(S3AccessKeyID::NAME);
auto secretAccessKey = main::ClientContext::getEnvVariable(S3SecretAccessKey::NAME);
auto endpoint = main::ClientContext::getEnvVariable(S3EndPoint::NAME);
auto urlStyle = main::ClientContext::getEnvVariable(S3URLStyle::NAME);
auto region = main::ClientContext::getEnvVariable(S3Region::NAME);
if (accessKeyID != "") {
context->setExtensionOption(S3AccessKeyID::NAME, Value::createValue(accessKeyID));
}
Expand Down
15 changes: 7 additions & 8 deletions src/include/main/client_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,21 +106,18 @@ class KUZU_API ClientContext {
extension::ExtensionOptions* getExtensionOptions() const;
std::string getExtensionDir() const;

// Environment.
std::string getEnvVariable(const std::string& name);

// Database component getters.
std::string getDatabasePath() const;
Database* getDatabase() const { return localDatabase; }
common::TaskScheduler* getTaskScheduler() const;
DatabaseManager* getDatabaseManager() const;
storage::StorageManager* getStorageManager() const;
storage::MemoryManager* getMemoryManager();
storage::MemoryManager* getMemoryManager() const;
storage::WAL* getWAL() const;
catalog::Catalog* getCatalog() const;
transaction::TransactionManager* getTransactionManagerUnsafe() const;
common::VirtualFileSystem* getVFSUnsafe() const;
common::RandomEngine* getRandomEngine();
common::RandomEngine* getRandomEngine() const;

// Query.
std::unique_ptr<PreparedStatement> prepare(std::string_view query);
Expand All @@ -131,7 +128,7 @@ class KUZU_API ClientContext {
std::optional<uint64_t> queryID = std::nullopt);

void setDefaultDatabase(AttachedKuzuDatabase* defaultDatabase_);
bool hasDefaultDatabase();
bool hasDefaultDatabase() const;

void addScalarFunction(std::string name, function::function_set definitions);
void removeScalarFunction(std::string name);
Expand All @@ -146,10 +143,12 @@ class KUZU_API ClientContext {
std::unique_ptr<QueryResult> queryInternal(std::string_view query,
std::optional<uint64_t> queryID = std::nullopt);

static std::string getEnvVariable(const std::string& name);

private:
std::vector<std::shared_ptr<parser::Statement>> parseQuery(std::string_view query);

std::unique_ptr<QueryResult> queryResultWithError(std::string_view errMsg);
static std::unique_ptr<QueryResult> queryResultWithError(std::string_view errMsg);

std::unique_ptr<PreparedStatement> preparedStatementWithError(std::string_view errMsg);

Expand Down Expand Up @@ -178,7 +177,7 @@ class KUZU_API ClientContext {
std::unique_ptr<QueryResult> executeNoLock(PreparedStatement* preparedStatement,
std::optional<uint64_t> queryID = std::nullopt);

bool canExecuteWriteQuery();
bool canExecuteWriteQuery() const;

void runFuncInTransaction(const std::function<void(void)>& fun);

Expand Down
12 changes: 10 additions & 2 deletions src/include/parser/statement.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,16 @@ namespace parser {

class Statement {
public:
explicit Statement(common::StatementType statementType) : statementType{statementType} {}
explicit Statement(common::StatementType statementType)
: statementType{statementType}, internal{false} {}

virtual ~Statement() = default;

common::StatementType getStatementType() const { return statementType; }
void setToInternal() { internal = true; }
bool isInternal() const { return internal; }

bool requireTx() {
bool requireTx() const {
switch (statementType) {
case common::StatementType::TRANSACTION:
return false;
Expand All @@ -38,6 +41,11 @@ class Statement {

private:
common::StatementType statementType;
// By setting the statement to internal, we still execute the statement, but will not return the
// executio result as part of the query result returned to users.
// The use case for this is when a query internally generates other queries to finish first,
// e.g., `TableFunction::rewriteFunc`.
bool internal;
};

} // namespace parser
Expand Down
20 changes: 12 additions & 8 deletions src/main/client_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ storage::StorageManager* ClientContext::getStorageManager() const {
}
}

storage::MemoryManager* ClientContext::getMemoryManager() {
storage::MemoryManager* ClientContext::getMemoryManager() const {
return localDatabase->memoryManager.get();
}

Expand Down Expand Up @@ -222,7 +222,7 @@ VirtualFileSystem* ClientContext::getVFSUnsafe() const {
return localDatabase->vfs.get();
}

RandomEngine* ClientContext::getRandomEngine() {
RandomEngine* ClientContext::getRandomEngine() const {
return randomEngine.get();
}

Expand Down Expand Up @@ -277,6 +277,9 @@ std::unique_ptr<QueryResult> ClientContext::queryInternal(std::string_view query
for (const auto& statement : parsedStatements) {
auto preparedStatement = prepareNoLock(statement, false /*requireNewTx*/);
auto currentQueryResult = executeNoLock(preparedStatement.get(), queryID);
if (statement->isInternal()) {
continue;
}
if (!lastResult) {
// first result of the query
queryResult = std::move(currentQueryResult);
Expand Down Expand Up @@ -377,12 +380,13 @@ std::vector<std::shared_ptr<Statement>> ClientContext::parseQuery(std::string_vi
if (!rewriteQuery.empty()) {
auto rewrittenStatements = Parser::parseQuery(rewriteQuery, this);
for (auto& statement : rewrittenStatements) {
statement->setToInternal();
statements.push_back(statement);
}
}
statements.push_back(parsedStatements[i]);
}
} catch (std::exception& exception) {
} catch (std::exception&) {
if (startNewTrx) {
transactionContext->rollback();
}
Expand All @@ -398,7 +402,7 @@ void ClientContext::setDefaultDatabase(AttachedKuzuDatabase* defaultDatabase_) {
remoteDatabase = defaultDatabase_;
}

bool ClientContext::hasDefaultDatabase() {
bool ClientContext::hasDefaultDatabase() const {
return remoteDatabase != nullptr;
}

Expand Down Expand Up @@ -513,7 +517,7 @@ std::unique_ptr<QueryResult> ClientContext::handleFailedExecution(

// If there is an active transaction in the context, we execute the function in current active
// transaction. If there is no active transaction, we start an auto commit transaction.
void ClientContext::runFuncInTransaction(const std::function<void(void)>& fun) {
void ClientContext::runFuncInTransaction(const std::function<void()>& fun) {
// check if we are on AutoCommit. In this case we should start a transaction
bool startNewTrx = !transactionContext->hasActiveTransaction();
if (startNewTrx) {
Expand Down Expand Up @@ -543,7 +547,7 @@ void ClientContext::removeScalarFunction(std::string name) {
runFuncInTransaction([&]() { localDatabase->catalog->dropFunction(getTx(), std::move(name)); });
}

bool ClientContext::canExecuteWriteQuery() {
bool ClientContext::canExecuteWriteQuery() const {
if (dbConfig.readOnly) {
return false;
}
Expand All @@ -558,11 +562,11 @@ bool ClientContext::canExecuteWriteQuery() {
return true;
}

processor::WarningContext& ClientContext::getWarningContextUnsafe() {
WarningContext& ClientContext::getWarningContextUnsafe() {
return warningContext;
}

const processor::WarningContext& ClientContext::getWarningContext() const {
const WarningContext& ClientContext::getWarningContext() const {
return warningContext;
}

Expand Down
6 changes: 4 additions & 2 deletions test/test_runner/test_runner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,11 @@ void TestRunner::checkPlanResult(Connection& conn, QueryResult* result, TestStat
}

void TestRunner::outputFailedPlan(Connection& conn, const TestStatement* statement) {
const auto plan = getLogicalPlan(statement->query, conn);
spdlog::error("QUERY FAILED.");
spdlog::info("PLAN: \n{}", plan->toString());
const auto plan = getLogicalPlan(statement->query, conn);
if (plan) {
spdlog::info("PLAN: \n{}", plan->toString());
}
}

bool TestRunner::checkResultNumeric(QueryResult& queryResult, const TestStatement* statement,
Expand Down

0 comments on commit 606eef8

Please sign in to comment.