Skip to content

Commit

Permalink
Allow prepared statemenet parameter in CALL function (#4628)
Browse files Browse the repository at this point in the history
  • Loading branch information
acquamarin authored Dec 13, 2024
1 parent ca95412 commit 2bb28c2
Show file tree
Hide file tree
Showing 27 changed files with 170 additions and 59 deletions.
3 changes: 1 addition & 2 deletions extension/delta/src/function/delta_scan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,8 @@ static std::unique_ptr<TableFuncBindData> bindFunc(main::ClientContext* context,
auto scanInput = ku_dynamic_cast<ExtraScanTableFuncBindInput*>(input->extraInput.get());
auto connector = std::make_shared<DeltaConnector>();
connector->connect("" /* inMemDB */, "" /* defaultCatalogName */, context);
input->getParam(0).validateType(LogicalTypeID::STRING);
std::string query = common::stringFormat("SELECT * FROM DELTA_SCAN('{}')",
input->getParam(0).getValue<std::string>());
input->getLiteralVal<std::string>(0));
auto result = connector->executeQuery(query + " LIMIT 1");
std::vector<LogicalType> returnTypes;
std::vector<std::string> returnColumnNames = scanInput->expectedColumnNames;
Expand Down
1 change: 1 addition & 0 deletions extension/fts/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ include_directories(
third_party/snowball/libstemmer)

add_subdirectory(src)
add_subdirectory(test)

add_library(fts_extension
SHARED
Expand Down
14 changes: 8 additions & 6 deletions extension/fts/src/function/create_fts_index.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "function/create_fts_index.h"

#include "binder/expression/expression_util.h"
#include "binder/expression/literal_expression.h"
#include "catalog/fts_index_catalog_entry.h"
#include "common/exception/binder.h"
#include "common/types/value/nested.h"
Expand Down Expand Up @@ -37,10 +38,11 @@ struct CreateFTSBindData final : public FTSBindData {
};

static std::vector<std::string> bindProperties(const catalog::NodeTableCatalogEntry& entry,
const common::Value& properties) {
std::shared_ptr<binder::Expression> properties) {
auto propertyValue = properties->constPtrCast<binder::LiteralExpression>()->getValue();
std::vector<std::string> result;
for (auto i = 0u; i < properties.getChildrenSize(); i++) {
auto propertyName = NestedVal::getChildVal(&properties, i)->toString();
for (auto i = 0u; i < propertyValue.getChildrenSize(); i++) {
auto propertyName = NestedVal::getChildVal(&propertyValue, i)->toString();
if (!entry.containsProperty(propertyName)) {
throw BinderException{common::stringFormat("Property: {} does not exist in table {}.",
propertyName, entry.getName())};
Expand All @@ -61,9 +63,9 @@ static void validateIndexNotExist(const main::ClientContext& context, common::ta
static std::unique_ptr<TableFuncBindData> bindFunc(ClientContext* context,
TableFuncBindInput* input) {
FTSUtils::validateAutoTrx(*context, CreateFTSFunction::name);
auto indexName = input->getParam(1).toString();
auto& nodeTableEntry = FTSUtils::bindTable(input->getParam(0), context, indexName,
FTSUtils::IndexOperation::CREATE);
auto indexName = input->getLiteralVal<std::string>(1);
auto& nodeTableEntry = FTSUtils::bindTable(input->getLiteralVal<std::string>(0), context,
indexName, FTSUtils::IndexOperation::CREATE);
auto properties = bindProperties(nodeTableEntry, input->getParam(2));
validateIndexNotExist(*context, nodeTableEntry.getTableID(), indexName);
auto createFTSConfig = FTSConfig{input->optionalParams};
Expand Down
6 changes: 3 additions & 3 deletions extension/fts/src/function/drop_fts_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ using namespace kuzu::function;
static std::unique_ptr<TableFuncBindData> bindFunc(ClientContext* context,
TableFuncBindInput* input) {
FTSUtils::validateAutoTrx(*context, DropFTSFunction::name);
auto indexName = input->getParam(1).toString();
auto& tableEntry =
FTSUtils::bindTable(input->getParam(0), context, indexName, FTSUtils::IndexOperation::DROP);
auto indexName = input->getLiteralVal<std::string>(1);
auto& tableEntry = FTSUtils::bindTable(input->getLiteralVal<std::string>(0), context, indexName,
FTSUtils::IndexOperation::DROP);
FTSUtils::validateIndexExistence(*context, tableEntry.getTableID(), indexName);
return std::make_unique<FTSBindData>(tableEntry.getName(), tableEntry.getTableID(), indexName,
std::vector<common::LogicalType>{}, std::vector<std::string>{});
Expand Down
10 changes: 4 additions & 6 deletions extension/fts/src/function/fts_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,12 @@
namespace kuzu {
namespace fts_extension {

catalog::NodeTableCatalogEntry& FTSUtils::bindTable(const common::Value& tableName,
catalog::NodeTableCatalogEntry& FTSUtils::bindTable(const std::string& tableName,
main::ClientContext* context, std::string indexName, IndexOperation operation) {
if (!context->getCatalog()->containsTable(context->getTx(), tableName.toString())) {
throw common::BinderException{
common::stringFormat("Table {} does not exist.", tableName.toString())};
if (!context->getCatalog()->containsTable(context->getTx(), tableName)) {
throw common::BinderException{common::stringFormat("Table {} does not exist.", tableName)};
}
auto tableEntry =
context->getCatalog()->getTableCatalogEntry(context->getTx(), tableName.toString());
auto tableEntry = context->getCatalog()->getTableCatalogEntry(context->getTx(), tableName);
if (tableEntry->getTableType() != common::TableType::NODE) {
switch (operation) {
case IndexOperation::CREATE:
Expand Down
55 changes: 39 additions & 16 deletions extension/fts/src/function/query_fts_index.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#include "function/query_fts_index.h"

#include "binder/expression/expression_util.h"
#include "binder/expression/literal_expression.h"
#include "binder/expression/parameter_expression.h"
#include "catalog/catalog.h"
#include "catalog/fts_index_catalog_entry.h"
#include "common/exception/binder.h"
Expand All @@ -21,22 +23,40 @@ using namespace kuzu::main;
using namespace kuzu::function;

struct QueryFTSBindData final : public FTSBindData {
std::string query;
std::shared_ptr<binder::Expression> query;
const FTSIndexCatalogEntry& entry;
QueryFTSConfig config;

QueryFTSBindData(std::string tableName, common::table_id_t tableID, std::string indexName,
std::string query, const FTSIndexCatalogEntry& entry, std::vector<LogicalType> returnTypes,
std::vector<std::string> returnColumnNames, QueryFTSConfig config)
std::shared_ptr<binder::Expression> query, const FTSIndexCatalogEntry& entry,
std::vector<LogicalType> returnTypes, std::vector<std::string> returnColumnNames,
QueryFTSConfig config)
: FTSBindData{std::move(tableName), tableID, std::move(indexName), std::move(returnTypes),
std::move(returnColumnNames)},
query{std::move(query)}, entry{entry}, config{std::move(config)} {}

std::string getQuery() const;

std::unique_ptr<TableFuncBindData> copy() const override {
return std::make_unique<QueryFTSBindData>(*this);
}
};

std::string QueryFTSBindData::getQuery() const {
auto value = Value::createDefaultValue(query->dataType);
switch (query->expressionType) {
case ExpressionType::LITERAL: {
value = query->constCast<binder::LiteralExpression>().getValue();
} break;
case ExpressionType::PARAMETER: {
value = query->constCast<binder::ParameterExpression>().getValue();
} break;
default:
KU_UNREACHABLE;
}
return value.getValue<std::string>();
}

struct QueryFTSLocalState : public TableFuncLocalState {
std::unique_ptr<QueryResult> result = nullptr;
uint64_t numRowsOutput = 0;
Expand All @@ -46,10 +66,13 @@ static std::unique_ptr<TableFuncBindData> bindFunc(ClientContext* context,
TableFuncBindInput* input) {
std::vector<std::string> columnNames;
std::vector<LogicalType> columnTypes;
auto indexName = input->getParam(1).toString();
auto& tableEntry = FTSUtils::bindTable(input->getParam(0), context, indexName,
FTSUtils::IndexOperation::QUERY);
auto query = input->getParam(2).toString();
// For queryFTS, the table and index name must be given at compile time while the user
// can give the query at runtime.
auto indexName = binder::ExpressionUtil::getLiteralValue<std::string>(*input->getParam(1));
auto& tableEntry = FTSUtils::bindTable(
binder::ExpressionUtil::getLiteralValue<std::string>(*input->getParam(0)), context,
indexName, FTSUtils::IndexOperation::QUERY);
auto query = input->getParam(2);
FTSUtils::validateIndexExistence(*context, tableEntry.getTableID(), indexName);
auto ftsCatalogEntry =
context->getCatalog()->getIndex(context->getTx(), tableEntry.getTableID(), indexName);
Expand All @@ -68,11 +91,12 @@ static common::offset_t tableFunc(TableFuncInput& data, TableFuncOutput& output)
// we need a wrapper call function to CALL the actual GDS function.
auto localState = data.localState->ptrCast<QueryFTSLocalState>();
if (localState->result == nullptr) {
auto bindData = data.bindData->constPtrCast<QueryFTSBindData>();
auto numDocs = bindData->entry.getNumDocs();
auto avgDocLen = bindData->entry.getAvgDocLen();
auto bindData = data.bindData->cast<QueryFTSBindData>();
auto actualQuery = bindData.getQuery();
auto numDocs = bindData.entry.getNumDocs();
auto avgDocLen = bindData.entry.getAvgDocLen();
auto query = common::stringFormat("UNWIND tokenize('{}') AS tk RETURN COUNT(DISTINCT tk);",
bindData->query);
actualQuery);
auto numTermsInQuery = data.context->clientContext
->queryInternal(query, "" /* encodedJoin */,
false /* enumerateAllPlans */, std::nullopt /* queryID */)
Expand All @@ -88,11 +112,10 @@ static common::offset_t tableFunc(TableFuncInput& data, TableFuncOutput& output)
"MATCH (p:`{}`) "
"WHERE _node.docID = offset(id(p)) "
"RETURN p, score",
bindData->getTermsTableName(), bindData->getDocsTableName(),
bindData->getAppearsInTableName(), bindData->query,
bindData->entry.getFTSConfig().stemmer, bindData->getTermsTableName(),
bindData->config.k, bindData->config.b, numDocs, avgDocLen, numTermsInQuery,
bindData->config.isConjunctive ? "true" : "false", bindData->tableName);
bindData.getTermsTableName(), bindData.getDocsTableName(),
bindData.getAppearsInTableName(), actualQuery, bindData.entry.getFTSConfig().stemmer,
bindData.getTermsTableName(), bindData.config.k, bindData.config.b, numDocs, avgDocLen,
numTermsInQuery, bindData.config.isConjunctive ? "true" : "false", bindData.tableName);
localState->result = data.context->clientContext->queryInternal(query, "", false,
std::nullopt /* queryID */);
}
Expand Down
2 changes: 1 addition & 1 deletion extension/fts/src/include/function/fts_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ struct FTSUtils {
DROP = 2,
};

static catalog::NodeTableCatalogEntry& bindTable(const common::Value& tableName,
static catalog::NodeTableCatalogEntry& bindTable(const std::string& tableName,
main::ClientContext* context, std::string indexName, IndexOperation indexOperation);

static void validateIndexExistence(const main::ClientContext& context,
Expand Down
3 changes: 3 additions & 0 deletions extension/fts/test/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
if (${BUILD_EXTENSION_TESTS})
add_kuzu_test(fts_prepare_test prepare_test.cpp)
endif ()
26 changes: 26 additions & 0 deletions extension/fts/test/prepare_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#include "main_test_helper/main_test_helper.h"

using namespace kuzu::common;

namespace kuzu {
namespace testing {

TEST_F(ApiTest, PrepareFTSTest) {
createDBAndConn();
ASSERT_TRUE(conn->query(common::stringFormat("LOAD EXTENSION '{}'",
TestHelper::appendKuzuRootPath(
"extension/fts/build/libfts.kuzu_extension")))
->isSuccess());
ASSERT_TRUE(
conn->query("CALL CREATE_FTS_INDEX('person', 'personIdx', ['fName'])")->isSuccess());
auto prepared =
conn->prepare("CALL QUERY_FTS_INDEX('person', 'personIdx', $q) RETURN node.ID, score;");
auto preparedResult = TestHelper::convertResultToString(
*conn->execute(prepared.get(), std::make_pair(std::string("q"), std::string("alice"))));
auto nonPreparedResult = TestHelper::convertResultToString(*conn->query(
"CALL QUERY_FTS_INDEX('person', 'personIdx', 'alice') RETURN node.ID, score;"));
sortAndCheckTestResults(preparedResult, nonPreparedResult);
}

} // namespace testing
} // namespace kuzu
2 changes: 1 addition & 1 deletion src/binder/bind/bind_file_scan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ std::unique_ptr<BoundBaseScanSource> Binder::bindFileScanSource(const BaseScanSo
auto func = getScanFunction(config->fileTypeInfo, *config);
// Bind table function
auto bindInput = TableFuncBindInput();
bindInput.addParam(Value::createValue(filePaths[0]));
bindInput.addLiteralParam(Value::createValue(filePaths[0]));
auto extraInput = std::make_unique<ExtraScanTableFuncBindInput>();
extraInput->config = config->copy();
extraInput->expectedColumnNames = columnNames;
Expand Down
20 changes: 12 additions & 8 deletions src/binder/bind/bind_table_function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,13 @@ using namespace kuzu::function;
namespace kuzu {
namespace binder {

static void validateParameterType(expression_vector positionalParams) {
for (auto& param : positionalParams) {
ExpressionUtil::validateExpressionType(*param,
{ExpressionType::LITERAL, ExpressionType::PARAMETER});
}
}

BoundTableFunction Binder::bindTableFunc(std::string tableFuncName,
const parser::ParsedExpression& expr, expression_vector& columns) {
auto entry = BuiltInFunctionsUtils::getFunctionCatalogEntry(clientContext->getTx(),
Expand All @@ -31,19 +38,16 @@ BoundTableFunction Binder::bindTableFunc(std::string tableFuncName,
}
}
auto func = BuiltInFunctionsUtils::matchFunction(tableFuncName, positionalParamTypes, entry);
std::vector<Value> inputValues;
for (auto& param : positionalParams) {
ExpressionUtil::validateExpressionType(*param, ExpressionType::LITERAL);
auto literalExpr = param->constPtrCast<LiteralExpression>();
inputValues.push_back(literalExpr->getValue());
}
validateParameterType(positionalParams);
auto tableFunc = func->constPtrCast<TableFunction>();
for (auto i = 0u; i < positionalParams.size(); ++i) {
auto parameterTypeID = tableFunc->parameterTypeIDs[i];
ExpressionUtil::validateDataType(*positionalParams[i], parameterTypeID);
if (positionalParams[i]->expressionType == ExpressionType::LITERAL) {
ExpressionUtil::validateDataType(*positionalParams[i], parameterTypeID);
}
}
auto bindInput = TableFuncBindInput();
bindInput.params = std::move(inputValues);
bindInput.params = std::move(positionalParams);
bindInput.optionalParams = std::move(optionalParams);
auto bindData = tableFunc->bindFunc(clientContext, &bindInput);
for (auto i = 0u; i < bindData->columnTypes.size(); i++) {
Expand Down
16 changes: 16 additions & 0 deletions src/binder/expression/expression_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,22 @@ void ExpressionUtil::validateExpressionType(const Expression& expr,
ExpressionTypeUtil::toString(expectedType)));
}

void ExpressionUtil::validateExpressionType(const Expression& expr,
std::vector<common::ExpressionType> expectedType) {
if (std::find(expectedType.begin(), expectedType.end(), expr.expressionType) !=
expectedType.end()) {
return;
}
std::string expectedTypesStr = "";
std::for_each(expectedType.begin(), expectedType.end(),
[&expectedTypesStr](common::ExpressionType type) {
expectedTypesStr += expectedTypesStr.empty() ? ExpressionTypeUtil::toString(type) :
"," + ExpressionTypeUtil::toString(type);
});
throw BinderException(stringFormat("{} has type {} but {} was expected.", expr.toString(),
ExpressionTypeUtil::toString(expr.expressionType), expectedTypesStr));
}

void ExpressionUtil::validateDataType(const Expression& expr, const LogicalType& expectedType) {
if (expr.getDataType() == expectedType) {
return;
Expand Down
3 changes: 2 additions & 1 deletion src/function/table/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ add_library(kuzu_table
OBJECT
bind_data.cpp
simple_table_functions.cpp
scan_functions.cpp)
scan_functions.cpp
bind_input.cpp)

set(ALL_OBJECT_FILES
${ALL_OBJECT_FILES} $<TARGET_OBJECTS:kuzu_table>
Expand Down
25 changes: 25 additions & 0 deletions src/function/table/bind_input.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#include "function/table/bind_input.h"

#include "binder/expression/literal_expression.h"

namespace kuzu {
namespace function {

void TableFuncBindInput::addLiteralParam(common::Value value) {
params.push_back(std::make_shared<binder::LiteralExpression>(std::move(value), ""));
}

template<typename T>
T TableFuncBindInput::getLiteralVal(common::idx_t idx) const {
KU_ASSERT(params[idx]->expressionType == common::ExpressionType::LITERAL);
return params[idx]->constCast<binder::LiteralExpression>().getValue().getValue<T>();
}

template KUZU_API std::string TableFuncBindInput::getLiteralVal<std::string>(
common::idx_t idx) const;
template KUZU_API uint64_t TableFuncBindInput::getLiteralVal<uint64_t>(common::idx_t idx) const;
template KUZU_API uint32_t TableFuncBindInput::getLiteralVal<uint32_t>(common::idx_t idx) const;
template KUZU_API uint8_t* TableFuncBindInput::getLiteralVal<uint8_t*>(common::idx_t idx) const;

} // namespace function
} // namespace kuzu
2 changes: 1 addition & 1 deletion src/function/table/call/current_setting.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ static common::offset_t tableFunc(TableFuncInput& data, TableFuncOutput& output)

static std::unique_ptr<TableFuncBindData> bindFunc(ClientContext* context,
TableFuncBindInput* input) {
auto optionName = input->getParam(0).getValue<std::string>();
auto optionName = input->getLiteralVal<std::string>(0);
std::vector<std::string> columnNames;
std::vector<LogicalType> columnTypes;
columnNames.emplace_back(optionName);
Expand Down
2 changes: 1 addition & 1 deletion src/function/table/call/show_connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ static std::unique_ptr<TableFuncBindData> bindFunc(ClientContext* context,
TableFuncBindInput* input) {
std::vector<std::string> columnNames;
std::vector<LogicalType> columnTypes;
auto tableName = input->getParam(0).getValue<std::string>();
auto tableName = input->getLiteralVal<std::string>(0);
auto catalog = context->getCatalog();
auto tableID = catalog->getTableID(context->getTx(), tableName);
auto tableEntry = catalog->getTableCatalogEntry(context->getTx(), tableID);
Expand Down
2 changes: 1 addition & 1 deletion src/function/table/call/stats_info.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ static offset_t tableFunc(TableFuncInput& input, TableFuncOutput& output) {

static std::unique_ptr<TableFuncBindData> bindFunc(ClientContext* context,
TableFuncBindInput* input) {
const auto tableName = input->getParam(0).getValue<std::string>();
const auto tableName = input->getLiteralVal<std::string>(0);
const auto catalog = context->getCatalog();
if (!catalog->containsTable(context->getTx(), tableName)) {
throw BinderException{"Table " + tableName + " does not exist!"};
Expand Down
2 changes: 1 addition & 1 deletion src/function/table/call/storage_info.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ static std::unique_ptr<TableFuncBindData> bindFunc(ClientContext* context,
columnTypes.emplace_back(LogicalType::STRING());
columnTypes.emplace_back(LogicalType::STRING());
columnTypes.emplace_back(LogicalType::STRING());
auto tableName = input->getParam(0).getValue<std::string>();
auto tableName = input->getLiteralVal<std::string>(0);
auto catalog = context->getCatalog();
if (!catalog->containsTable(context->getTx(), tableName)) {
throw BinderException{"Table " + tableName + " does not exist!"};
Expand Down
2 changes: 1 addition & 1 deletion src/function/table/call/table_info.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ static std::unique_ptr<TableFuncBindData> bindFunc(main::ClientContext* context,
TableFuncBindInput* input) {
std::vector<std::string> columnNames;
std::vector<LogicalType> columnTypes;
auto catalogEntry = getTableCatalogEntry(context, input->getParam(0).getValue<std::string>());
auto catalogEntry = getTableCatalogEntry(context, input->getLiteralVal<std::string>(0));
auto tableEntry = catalogEntry->constPtrCast<TableCatalogEntry>();
columnNames.emplace_back("property id");
columnTypes.push_back(LogicalType::INT32());
Expand Down
2 changes: 2 additions & 0 deletions src/include/binder/expression/expression_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ struct KUZU_API ExpressionUtil {
static bool isEmptyList(const Expression& expression);

static void validateExpressionType(const Expression& expr, common::ExpressionType expectedType);
static void validateExpressionType(const Expression& expr,
std::vector<common::ExpressionType> expectedType);

// Validate data type.
static void validateDataType(const Expression& expr, const common::LogicalType& expectedType);
Expand Down
Loading

0 comments on commit 2bb28c2

Please sign in to comment.