diff --git a/src/binder/bind/bind_graph_pattern.cpp b/src/binder/bind/bind_graph_pattern.cpp index d3871f20f57..c539fc6a36b 100644 --- a/src/binder/bind/bind_graph_pattern.cpp +++ b/src/binder/bind/bind_graph_pattern.cpp @@ -12,6 +12,8 @@ #include "common/string_format.h" #include "common/utils.h" #include "function/cast/functions/cast_from_string_functions.h" +#include "function/rewrite_function.h" +#include "function/schema/vector_node_rel_functions.h" #include "main/client_context.h" using namespace kuzu::common; @@ -281,7 +283,8 @@ std::shared_ptr Binder::createNonRecursiveQueryRel(const std::str fields.emplace_back(InternalKeyword::SRC, LogicalType::INTERNAL_ID()); fields.emplace_back(InternalKeyword::DST, LogicalType::INTERNAL_ID()); // Bind internal expressions. - queryRel->setLabelExpression(expressionBinder.bindLabelFunction(*queryRel)); + auto input = function::RewriteFunctionBindInput(clientContext, &expressionBinder, {queryRel}); + queryRel->setLabelExpression(function::LabelFunction::rewriteFunc(input)); fields.emplace_back(InternalKeyword::LABEL, queryRel->getLabelExpression()->getDataType().copy()); // Bind properties. @@ -576,7 +579,8 @@ std::shared_ptr Binder::createQueryNode(const std::string& parse // Bind internal expressions queryNode->setInternalID( PropertyExpression::construct(LogicalType::INTERNAL_ID(), InternalKeyword::ID, *queryNode)); - queryNode->setLabelExpression(expressionBinder.bindLabelFunction(*queryNode)); + auto input = function::RewriteFunctionBindInput(clientContext, &expressionBinder, {queryNode}); + queryNode->setLabelExpression(function::LabelFunction::rewriteFunc(input)); fieldNames.emplace_back(InternalKeyword::ID); fieldNames.emplace_back(InternalKeyword::LABEL); fieldTypes.push_back(queryNode->getInternalID()->getDataType().copy()); diff --git a/src/binder/bind_expression/bind_function_expression.cpp b/src/binder/bind_expression/bind_function_expression.cpp index 7644545d89d..3c2c752db1c 100644 --- a/src/binder/bind_expression/bind_function_expression.cpp +++ b/src/binder/bind_expression/bind_function_expression.cpp @@ -1,6 +1,5 @@ #include "binder/binder.h" #include "binder/expression/aggregate_function_expression.h" -#include "binder/expression/expression_util.h" #include "binder/expression/scalar_function_expression.h" #include "binder/expression_binder.h" #include "catalog/catalog.h" @@ -10,8 +9,6 @@ #include "function/cast/vector_cast_functions.h" #include "function/rewrite_function.h" #include "function/scalar_macro_function.h" -#include "function/schema/vector_label_functions.h" -#include "function/schema/vector_node_rel_functions.h" #include "main/client_context.h" #include "parser/expression/parsed_function_expression.h" #include "parser/parsed_expression_visitor.h" @@ -27,10 +24,6 @@ namespace binder { std::shared_ptr ExpressionBinder::bindFunctionExpression(const ParsedExpression& expr) { auto funcExpr = expr.constPtrCast(); auto functionName = funcExpr->getNormalizedFunctionName(); - auto result = rewriteFunctionExpression(expr, functionName); - if (result != nullptr) { - return result; - } auto entry = context->getCatalog()->getFunctionEntry(context->getTransaction(), functionName); switch (entry->getType()) { case CatalogEntryType::SCALAR_FUNCTION_ENTRY: @@ -89,8 +82,9 @@ std::shared_ptr ExpressionBinder::bindScalarFunctionExpression( } expression_vector childrenAfterCast; std::unique_ptr bindData; + auto bindInput = ScalarBindFuncInput{children, function.get(), context}; if (functionName == CastAnyFunction::name) { - bindData = function->bindFunc({children, function.get(), context}); + bindData = function->bindFunc(bindInput); if (bindData == nullptr) { // No need to cast. // TODO(Xiyang): We should return a deep copy otherwise the same expression might // appear in the final projection list repeatedly. @@ -104,7 +98,7 @@ std::shared_ptr ExpressionBinder::bindScalarFunctionExpression( childrenAfterCast.push_back(std::move(childAfterCast)); } else { if (function->bindFunc) { - bindData = function->bindFunc({children, function.get(), context}); + bindData = function->bindFunc(bindInput); } else { bindData = std::make_unique(LogicalType(function->returnTypeID)); } @@ -142,7 +136,8 @@ std::shared_ptr ExpressionBinder::bindRewriteFunctionExpression( entry->ptrCast()); auto function = match->constPtrCast(); KU_ASSERT(function->rewriteFunc != nullptr); - return function->rewriteFunc(children, this); + auto input = RewriteFunctionBindInput(context, this, children); + return function->rewriteFunc(input); } std::shared_ptr ExpressionBinder::bindAggregateFunctionExpression( @@ -173,7 +168,8 @@ std::shared_ptr ExpressionBinder::bindAggregateFunctionExpression( } std::unique_ptr bindData; if (function.bindFunc) { - bindData = function.bindFunc({children, &function, context}); + auto bindInput = ScalarBindFuncInput{children, &function, context}; + bindData = function.bindFunc(bindInput); } else { bindData = std::make_unique(LogicalType(function.returnTypeID)); } @@ -207,111 +203,5 @@ std::shared_ptr ExpressionBinder::bindMacroExpression( return bindExpression(*macroParameterReplacer->visit(std::move(macroExpr))); } -// Function rewriting happens when we need to expose internal property access through function so -// that it becomes read-only or the function involves catalog information. Currently we write -// Before | After -// LABEL(a) | LIST_EXTRACT(offset(a), [table names from catalog]) -// STARTNODE(a) | a._src -// ENDNODE(a) | a._dst -std::shared_ptr ExpressionBinder::rewriteFunctionExpression( - const parser::ParsedExpression& parsedExpression, const std::string& functionName) { - if (functionName == LabelFunction::name) { - auto child = bindExpression(*parsedExpression.getChild(0)); - ExpressionUtil::validateDataType(*child, - std::vector{LogicalTypeID::NODE, LogicalTypeID::REL}); - return bindLabelFunction(*child); - } else if (functionName == StartNodeFunction::name) { - auto child = bindExpression(*parsedExpression.getChild(0)); - ExpressionUtil::validateDataType(*child, LogicalTypeID::REL); - return bindStartNodeExpression(*child); - } else if (functionName == EndNodeFunction::name) { - auto child = bindExpression(*parsedExpression.getChild(0)); - ExpressionUtil::validateDataType(*child, LogicalTypeID::REL); - return bindEndNodeExpression(*child); - } - return nullptr; -} - -std::shared_ptr ExpressionBinder::bindStartNodeExpression( - const Expression& expression) { - return expression.constCast().getSrcNode(); -} - -std::shared_ptr ExpressionBinder::bindEndNodeExpression(const Expression& expression) { - return expression.constCast().getDstNode(); -} - -static std::vector> populateLabelValues(const main::ClientContext* context, - const std::vector& entries) { - std::unordered_map map; - table_id_t maxTableID = 0; - for (auto& entry : entries) { - map.insert({entry->getTableID(), - entry->getLabel(context->getCatalog(), context->getTransaction())}); - if (entry->getTableID() > maxTableID) { - maxTableID = entry->getTableID(); - } - } - std::vector> labels; - labels.resize(maxTableID + 1); - for (auto i = 0u; i < labels.size(); ++i) { - if (map.contains(i)) { - labels[i] = std::make_unique(LogicalType::STRING(), map.at(i)); - } else { - // TODO(Xiyang/Guodong): change to null literal once we support null in LIST type. - labels[i] = std::make_unique(LogicalType::STRING(), std::string("")); - } - } - return labels; -} - -std::shared_ptr ExpressionBinder::bindLabelFunction( - const Expression& expression) const { - auto listType = LogicalType::LIST(LogicalType::STRING()); - auto catalog = context->getCatalog(); - auto transaction = context->getTransaction(); - expression_vector children; - switch (expression.getDataType().getLogicalTypeID()) { - case LogicalTypeID::NODE: { - auto& node = expression.constCast(); - if (node.isEmpty()) { - return createLiteralExpression(""); - } - if (!node.isMultiLabeled()) { - return createLiteralExpression(Value(LogicalType::STRING(), - node.getSingleEntry()->getLabel(catalog, transaction))); - } - children.push_back(node.getInternalID()); - auto labelsValue = - Value(std::move(listType), populateLabelValues(context, node.getEntries())); - children.push_back(createLiteralExpression(labelsValue)); - } break; - case LogicalTypeID::REL: { - auto& rel = expression.constCast(); - if (rel.isEmpty()) { - return createLiteralExpression(""); - } - if (!rel.isMultiLabeled()) { - return createLiteralExpression( - Value(LogicalType::STRING(), rel.getSingleEntry()->getLabel(catalog, transaction))); - } - children.push_back(rel.getInternalIDProperty()); - auto labelsValue = - Value(std::move(listType), populateLabelValues(context, rel.getEntries())); - children.push_back(createLiteralExpression(labelsValue)); - } break; - default: - KU_UNREACHABLE; - } - auto function = std::make_unique(LabelFunction::name, - std::vector{LogicalTypeID::STRING, LogicalTypeID::INT64}, - LogicalTypeID::STRING, LabelFunction::execFunction); - auto bindData = std::make_unique(LogicalType::STRING()); - auto uniqueExpressionName = - ScalarFunctionExpression::getUniqueName(LabelFunction::name, children); - return std::make_shared(ExpressionType::FUNCTION, std::move(function), - std::move(bindData), std::move(children), uniqueExpressionName); -} - } // namespace binder } // namespace kuzu diff --git a/src/expression_evaluator/function_evaluator.cpp b/src/expression_evaluator/function_evaluator.cpp index ebaaa8b9656..29a19dcd923 100644 --- a/src/expression_evaluator/function_evaluator.cpp +++ b/src/expression_evaluator/function_evaluator.cpp @@ -61,7 +61,6 @@ bool FunctionExpressionEvaluator::select(SelectionVector& selVector) { // the correctness, because when all children are flat the check is done on tuple. auto pos = resultVector->state->getSelVector()[0]; return resultVector->isNull(pos) ? 0 : resultVector->getValue(pos); - ; } } return function->selectFunc(parameters, selVector); diff --git a/src/function/function_collection.cpp b/src/function/function_collection.cpp index fa3b9220653..4a680bec3b3 100644 --- a/src/function/function_collection.cpp +++ b/src/function/function_collection.cpp @@ -186,6 +186,8 @@ FunctionCollection* FunctionCollection::getFunctions() { // Node/rel functions SCALAR_FUNCTION(OffsetFunction), REWRITE_FUNCTION(IDFunction), + REWRITE_FUNCTION(StartNodeFunction), REWRITE_FUNCTION(EndNodeFunction), + REWRITE_FUNCTION(LabelFunction), // Path functions SCALAR_FUNCTION(NodesFunction), SCALAR_FUNCTION(RelsFunction), diff --git a/src/function/path/length_function.cpp b/src/function/path/length_function.cpp index f7ff67b733d..d26e619cecf 100644 --- a/src/function/path/length_function.cpp +++ b/src/function/path/length_function.cpp @@ -12,10 +12,10 @@ using namespace kuzu::common; namespace kuzu { namespace function { -static std::shared_ptr rewriteFunc(const expression_vector& params, - ExpressionBinder* binder) { - KU_ASSERT(params.size() == 1); - auto param = params[0].get(); +static std::shared_ptr rewriteFunc(const RewriteFunctionBindInput& input) { + KU_ASSERT(input.arguments.size() == 1); + auto param = input.arguments[0].get(); + auto binder = input.expressionBinder; if (param->expressionType == ExpressionType::PATH) { int64_t numRels = 0u; std::vector recursiveRels; diff --git a/src/function/pattern/CMakeLists.txt b/src/function/pattern/CMakeLists.txt index 4857f7f50d3..295bdef1bec 100644 --- a/src/function/pattern/CMakeLists.txt +++ b/src/function/pattern/CMakeLists.txt @@ -1,6 +1,8 @@ add_library(kuzu_function_pattern OBJECT - id_function.cpp) + id_function.cpp + label_function.cpp + start_end_node_function.cpp) set(ALL_OBJECT_FILES ${ALL_OBJECT_FILES} $ diff --git a/src/function/pattern/id_function.cpp b/src/function/pattern/id_function.cpp index 11df6dcdbe5..f1a0946f8a7 100644 --- a/src/function/pattern/id_function.cpp +++ b/src/function/pattern/id_function.cpp @@ -2,7 +2,6 @@ #include "binder/expression/node_expression.h" #include "binder/expression/rel_expression.h" #include "binder/expression_binder.h" -#include "common/types/value/value.h" #include "function/rewrite_function.h" #include "function/schema/vector_node_rel_functions.h" #include "function/struct/vector_struct_functions.h" @@ -13,10 +12,9 @@ using namespace kuzu::binder; namespace kuzu { namespace function { -static std::shared_ptr rewriteFunc(const expression_vector& params, - ExpressionBinder* binder) { - KU_ASSERT(params.size() == 1); - auto param = params[0].get(); +static std::shared_ptr rewriteFunc(const RewriteFunctionBindInput& input) { + KU_ASSERT(input.arguments.size() == 1); + auto param = input.arguments[0].get(); if (ExpressionUtil::isNodePattern(*param)) { auto node = param->constPtrCast(); return node->getInternalID(); @@ -26,10 +24,9 @@ static std::shared_ptr rewriteFunc(const expression_vector& params, return rel->getPropertyExpression(InternalKeyword::ID); } // Bind as struct_extract(param, "_id") - auto keyExpr = - binder->createLiteralExpression(Value(LogicalType::STRING(), InternalKeyword::ID)); - auto newParams = expression_vector{params[0], keyExpr}; - return binder->bindScalarFunctionExpression(newParams, StructExtractFunctions::name); + auto extractKey = input.expressionBinder->createLiteralExpression(InternalKeyword::ID); + return input.expressionBinder->bindScalarFunctionExpression({input.arguments[0], extractKey}, + StructExtractFunctions::name); } function_set IDFunction::getFunctionSet() { diff --git a/src/function/pattern/label_function.cpp b/src/function/pattern/label_function.cpp new file mode 100644 index 00000000000..6665f46400e --- /dev/null +++ b/src/function/pattern/label_function.cpp @@ -0,0 +1,123 @@ +#include "binder/expression/expression_util.h" +#include "binder/expression/node_expression.h" +#include "binder/expression/rel_expression.h" +#include "binder/expression/scalar_function_expression.h" +#include "binder/expression_binder.h" +#include "catalog/catalog_entry/table_catalog_entry.h" +#include "function/binary_function_executor.h" +#include "function/list/functions/list_extract_function.h" +#include "function/rewrite_function.h" +#include "function/scalar_function.h" +#include "function/schema/vector_node_rel_functions.h" +#include "function/struct/vector_struct_functions.h" +#include "main/client_context.h" + +using namespace kuzu::common; +using namespace kuzu::binder; +using namespace kuzu::catalog; + +namespace kuzu { +namespace function { + +struct Label { + static inline void operation(common::internalID_t& left, common::list_entry_t& right, + common::ku_string_t& result, common::ValueVector& leftVector, + common::ValueVector& rightVector, common::ValueVector& resultVector, uint64_t resPos) { + KU_ASSERT(left.tableID < right.size); + ListExtract::operation(right, left.tableID + 1 /* listExtract requires 1-based index */, + result, rightVector, leftVector, resultVector, resPos); + } +}; + +static void execFunction(const std::vector>& params, + common::ValueVector& result, void* /*dataPtr*/ = nullptr) { + KU_ASSERT(params.size() == 2); + BinaryFunctionExecutor::executeListExtract(*params[0], *params[1], result); +} + +static std::shared_ptr getLabelsAsLiteral(main::ClientContext* context, + std::vector entries, binder::ExpressionBinder* expressionBinder) { + std::unordered_map map; + table_id_t maxTableID = 0; + for (auto& entry : entries) { + map.insert({entry->getTableID(), + entry->getLabel(context->getCatalog(), context->getTransaction())}); + if (entry->getTableID() > maxTableID) { + maxTableID = entry->getTableID(); + } + } + std::vector> labels; + labels.resize(maxTableID + 1); + for (auto i = 0u; i < labels.size(); ++i) { + if (map.contains(i)) { + labels[i] = std::make_unique(LogicalType::STRING(), map.at(i)); + } else { + labels[i] = std::make_unique(LogicalType::STRING(), std::string("")); + } + } + auto labelsValue = Value(LogicalType::LIST(LogicalType::STRING()), std::move(labels)); + return expressionBinder->createLiteralExpression(labelsValue); +} + +std::shared_ptr LabelFunction::rewriteFunc(const RewriteFunctionBindInput& input) { + KU_ASSERT(input.arguments.size() == 1); + auto argument = input.arguments[0].get(); + auto expressionBinder = input.expressionBinder; + auto context = input.context; + expression_vector children; + if (argument->expressionType == ExpressionType::VARIABLE) { + children.push_back(input.arguments[0]); + children.push_back(expressionBinder->createLiteralExpression(InternalKeyword::LABEL)); + return expressionBinder->bindScalarFunctionExpression(children, + StructExtractFunctions::name); + } + if (ExpressionUtil::isNodePattern(*argument)) { + auto& node = argument->constCast(); + if (node.isEmpty()) { + return expressionBinder->createLiteralExpression(""); + } + if (!node.isMultiLabeled()) { + auto label = + node.getSingleEntry()->getLabel(context->getCatalog(), context->getTransaction()); + return expressionBinder->createLiteralExpression(label); + } + children.push_back(node.getInternalID()); + children.push_back(getLabelsAsLiteral(context, node.getEntries(), expressionBinder)); + } else if (ExpressionUtil::isRelPattern(*argument)) { + auto& rel = argument->constCast(); + if (rel.isEmpty()) { + return expressionBinder->createLiteralExpression(""); + } + if (!rel.isMultiLabeled()) { + auto label = + rel.getSingleEntry()->getLabel(context->getCatalog(), context->getTransaction()); + return expressionBinder->createLiteralExpression(label); + } + children.push_back(rel.getInternalIDProperty()); + children.push_back(getLabelsAsLiteral(context, rel.getEntries(), expressionBinder)); + } + KU_ASSERT(children.size() == 2); + auto function = std::make_unique(LabelFunction::name, + std::vector{LogicalTypeID::STRING, LogicalTypeID::INT64}, + LogicalTypeID::STRING, execFunction); + auto bindData = std::make_unique(LogicalType::STRING()); + auto uniqueName = ScalarFunctionExpression::getUniqueName(LabelFunction::name, children); + return std::make_shared(ExpressionType::FUNCTION, std::move(function), + std::move(bindData), std::move(children), uniqueName); +} + +function_set LabelFunction::getFunctionSet() { + function_set set; + auto inputTypes = + std::vector{LogicalTypeID::NODE, LogicalTypeID::REL, LogicalTypeID::STRUCT}; + for (auto& inputType : inputTypes) { + auto function = std::make_unique(name, + std::vector{inputType}, rewriteFunc); + set.push_back(std::move(function)); + } + return set; +} + +} // namespace function +} // namespace kuzu diff --git a/src/function/pattern/start_end_node_function.cpp b/src/function/pattern/start_end_node_function.cpp new file mode 100644 index 00000000000..27411063e72 --- /dev/null +++ b/src/function/pattern/start_end_node_function.cpp @@ -0,0 +1,53 @@ +#include "binder/expression/expression_util.h" +#include "binder/expression/rel_expression.h" +#include "binder/expression_binder.h" +#include "function/rewrite_function.h" +#include "function/schema/vector_node_rel_functions.h" +#include "function/struct/vector_struct_functions.h" + +using namespace kuzu::common; +using namespace kuzu::binder; + +namespace kuzu { +namespace function { + +static std::shared_ptr startRewriteFunc(const RewriteFunctionBindInput& input) { + KU_ASSERT(input.arguments.size() == 1); + auto param = input.arguments[0].get(); + if (ExpressionUtil::isRelPattern(*param)) { + return param->constCast().getSrcNode(); + } + auto extractKey = input.expressionBinder->createLiteralExpression(InternalKeyword::SRC); + return input.expressionBinder->bindScalarFunctionExpression({input.arguments[0], extractKey}, + StructExtractFunctions::name); +} + +function_set StartNodeFunction::getFunctionSet() { + function_set set; + auto function = std::make_unique(name, + std::vector{LogicalTypeID::REL}, startRewriteFunc); + set.push_back(std::move(function)); + return set; +} + +static std::shared_ptr endRewriteFunc(const RewriteFunctionBindInput& input) { + KU_ASSERT(input.arguments.size() == 1); + auto param = input.arguments[0].get(); + if (ExpressionUtil::isRelPattern(*param)) { + return param->constCast().getDstNode(); + } + auto extractKey = input.expressionBinder->createLiteralExpression(InternalKeyword::DST); + return input.expressionBinder->bindScalarFunctionExpression({input.arguments[0], extractKey}, + StructExtractFunctions::name); +} + +function_set EndNodeFunction::getFunctionSet() { + function_set set; + auto function = std::make_unique(name, + std::vector{LogicalTypeID::REL}, endRewriteFunc); + set.push_back(std::move(function)); + return set; +} + +} // namespace function +} // namespace kuzu diff --git a/src/function/struct/keys_function.cpp b/src/function/struct/keys_function.cpp index 399c81bbad2..89dbac4bc8a 100644 --- a/src/function/struct/keys_function.cpp +++ b/src/function/struct/keys_function.cpp @@ -9,13 +9,12 @@ using namespace kuzu::binder; namespace kuzu { namespace function { -static std::shared_ptr rewriteFunc(const expression_vector& params, - ExpressionBinder* /*binder*/) { - KU_ASSERT(params.size() == 1); +static std::shared_ptr rewriteFunc(const RewriteFunctionBindInput& input) { + KU_ASSERT(input.arguments.size() == 1); auto uniqueExpressionName = - ScalarFunctionExpression::getUniqueName(KeysFunctions::name, params); + ScalarFunctionExpression::getUniqueName(KeysFunctions::name, input.arguments); const auto& resultType = LogicalType::LIST(LogicalType::STRING()); - auto fields = common::StructType::getFieldNames(params[0]->dataType); + auto fields = common::StructType::getFieldNames(input.arguments[0]->dataType); std::vector> children; for (auto field : fields) { if (field == InternalKeyword::ID || field == InternalKeyword::LABEL || diff --git a/src/function/utility/nullif.cpp b/src/function/utility/nullif.cpp index 1704d1e3d2e..e0c58f19674 100644 --- a/src/function/utility/nullif.cpp +++ b/src/function/utility/nullif.cpp @@ -10,15 +10,15 @@ using namespace kuzu::common; namespace kuzu { namespace function { -static std::shared_ptr rewriteFunc(const expression_vector& params, - ExpressionBinder* binder) { - KU_ASSERT(params.size() == 2); +static std::shared_ptr rewriteFunc(const RewriteFunctionBindInput& input) { + KU_ASSERT(input.arguments.size() == 2); auto uniqueExpressionName = - ScalarFunctionExpression::getUniqueName(NullIfFunction::name, params); - const auto& resultType = params[0]->getDataType(); - auto caseExpression = - std::make_shared(resultType.copy(), params[0], uniqueExpressionName); - auto whenExpression = binder->bindComparisonExpression(ExpressionType::EQUALS, params); + ScalarFunctionExpression::getUniqueName(NullIfFunction::name, input.arguments); + const auto& resultType = input.arguments[0]->getDataType(); + auto caseExpression = std::make_shared(resultType.copy(), input.arguments[0], + uniqueExpressionName); + auto binder = input.expressionBinder; + auto whenExpression = binder->bindComparisonExpression(ExpressionType::EQUALS, input.arguments); auto thenExpression = binder->createNullLiteralExpression(); thenExpression = binder->implicitCastIfNecessary(thenExpression, resultType.copy()); caseExpression->addCaseAlternative(whenExpression, thenExpression); diff --git a/src/function/vector_node_rel_functions.cpp b/src/function/vector_node_rel_functions.cpp index dc223395731..1d3afb9dc03 100644 --- a/src/function/vector_node_rel_functions.cpp +++ b/src/function/vector_node_rel_functions.cpp @@ -10,17 +10,16 @@ using namespace kuzu::common; namespace kuzu { namespace function { -static void OffsetExecFunc(const std::vector>& params, - ValueVector& result, void* /*dataPtr*/ = nullptr) { +static void execFunc(const std::vector>& params, ValueVector& result, + void* /*dataPtr*/ = nullptr) { KU_ASSERT(params.size() == 1); UnaryFunctionExecutor::execute(*params[0], result); } function_set OffsetFunction::getFunctionSet() { function_set functionSet; - functionSet.push_back( - make_unique(name, std::vector{LogicalTypeID::INTERNAL_ID}, - LogicalTypeID::INT64, OffsetExecFunc)); + functionSet.push_back(make_unique(name, + std::vector{LogicalTypeID::INTERNAL_ID}, LogicalTypeID::INT64, execFunc)); return functionSet; } diff --git a/src/include/binder/expression_binder.h b/src/include/binder/expression_binder.h index 8d6116534e7..ab2f235e107 100644 --- a/src/include/binder/expression_binder.h +++ b/src/include/binder/expression_binder.h @@ -9,6 +9,10 @@ namespace main { class ClientContext; } +namespace function { +struct Function; +} + namespace binder { class Binder; @@ -74,12 +78,6 @@ class ExpressionBinder { std::shared_ptr bindMacroExpression( const parser::ParsedExpression& parsedExpression, const std::string& macroName); - std::shared_ptr rewriteFunctionExpression( - const parser::ParsedExpression& parsedExpression, const std::string& functionName); - static std::shared_ptr bindStartNodeExpression(const Expression& expression); - static std::shared_ptr bindEndNodeExpression(const Expression& expression); - std::shared_ptr bindLabelFunction(const Expression& expression) const; - // Parameter expressions. std::shared_ptr bindParameterExpression( const parser::ParsedExpression& parsedExpression); diff --git a/src/include/function/function.h b/src/include/function/function.h index de298a9045e..b34f68b066c 100644 --- a/src/include/function/function.h +++ b/src/include/function/function.h @@ -47,9 +47,9 @@ struct ScalarBindFuncInput { Function* definition; main::ClientContext* context; - ScalarBindFuncInput(const binder::expression_vector& expressionVectors, Function* definition, + ScalarBindFuncInput(const binder::expression_vector& arguments, Function* definition, main::ClientContext* context) - : arguments{expressionVectors}, definition{definition}, context{context} {} + : arguments{arguments}, definition{definition}, context{context} {} }; using scalar_bind_func = diff --git a/src/include/function/rewrite_function.h b/src/include/function/rewrite_function.h index 702e6e220ae..f7bbf7a7c13 100644 --- a/src/include/function/rewrite_function.h +++ b/src/include/function/rewrite_function.h @@ -8,9 +8,19 @@ class ExpressionBinder; } namespace function { +struct RewriteFunctionBindInput { + main::ClientContext* context; + binder::ExpressionBinder* expressionBinder; + binder::expression_vector arguments; + + RewriteFunctionBindInput(main::ClientContext* context, + binder::ExpressionBinder* expressionBinder, binder::expression_vector arguments) + : context{context}, expressionBinder{expressionBinder}, arguments{std::move(arguments)} {} +}; + // Rewrite function to a different expression, e.g. id(n) -> n._id. -using rewrite_func_rewrite_t = std::function( - const binder::expression_vector&, binder::ExpressionBinder*)>; +using rewrite_func_rewrite_t = + std::function(const RewriteFunctionBindInput&)>; // We write for the following functions // ID(n) -> n._id diff --git a/src/include/function/schema/label_functions.h b/src/include/function/schema/label_functions.h deleted file mode 100644 index 5ddbf445d13..00000000000 --- a/src/include/function/schema/label_functions.h +++ /dev/null @@ -1,19 +0,0 @@ -#pragma once - -#include "function/list/functions/list_extract_function.h" - -namespace kuzu { -namespace function { - -struct Label { - static inline void operation(common::internalID_t& left, common::list_entry_t& right, - common::ku_string_t& result, common::ValueVector& leftVector, - common::ValueVector& rightVector, common::ValueVector& resultVector, uint64_t resPos) { - KU_ASSERT(left.tableID < right.size); - ListExtract::operation(right, left.tableID + 1 /* listExtract requires 1-based index */, - result, rightVector, leftVector, resultVector, resPos); - } -}; - -} // namespace function -} // namespace kuzu diff --git a/src/include/function/schema/vector_label_functions.h b/src/include/function/schema/vector_label_functions.h deleted file mode 100644 index 5152ab1d615..00000000000 --- a/src/include/function/schema/vector_label_functions.h +++ /dev/null @@ -1,21 +0,0 @@ -#pragma once - -#include "function/binary_function_executor.h" -#include "function/schema/label_functions.h" - -namespace kuzu { -namespace function { - -struct LabelFunction { - static constexpr const char* name = "LABEL"; - - static void execFunction(const std::vector>& params, - common::ValueVector& result, void* /*dataPtr*/ = nullptr) { - KU_ASSERT(params.size() == 2); - BinaryFunctionExecutor::executeListExtract(*params[0], *params[1], result); - } -}; - -} // namespace function -} // namespace kuzu diff --git a/src/include/function/schema/vector_node_rel_functions.h b/src/include/function/schema/vector_node_rel_functions.h index 454ff89843d..ec6908136bb 100644 --- a/src/include/function/schema/vector_node_rel_functions.h +++ b/src/include/function/schema/vector_node_rel_functions.h @@ -5,6 +5,8 @@ namespace kuzu { namespace function { +struct RewriteFunctionBindInput; + struct OffsetFunction { static constexpr const char* name = "OFFSET"; @@ -19,10 +21,21 @@ struct IDFunction { struct StartNodeFunction { static constexpr const char* name = "START_NODE"; + + static function_set getFunctionSet(); }; struct EndNodeFunction { static constexpr const char* name = "END_NODE"; + + static function_set getFunctionSet(); +}; + +struct LabelFunction { + static constexpr const char* name = "LABEL"; + + static function_set getFunctionSet(); + static std::shared_ptr rewriteFunc(const RewriteFunctionBindInput& input); }; } // namespace function diff --git a/test/test_files/function/lambda/list_transform.test b/test/test_files/function/lambda/list_transform.test index a2c2e434175..6070fc1b393 100644 --- a/test/test_files/function/lambda/list_transform.test +++ b/test/test_files/function/lambda/list_transform.test @@ -4,6 +4,12 @@ -CASE ListTransform +-STATEMENT MATCH p = (a:person)-[:knows*1..2]->(b) WHERE a.ID = 0 AND b.ID=2 RETURN LIST_TRANSFORM(rels(p), x->label(x)), LIST_TRANSFORM(nodes(p), x->label(x)) +---- 3 +[knows,knows]|[person,person,person] +[knows,knows]|[person,person,person] +[knows]|[person,person] + -STATEMENT RETURN list_transform(list_transform([1,3,2], x->x+5), y->y+3) ---- 1 [9,11,10] diff --git a/test/test_files/function/start_end_node.test b/test/test_files/function/start_end_node.test index b8c0ed332e5..a506392a3c3 100644 --- a/test/test_files/function/start_end_node.test +++ b/test/test_files/function/start_end_node.test @@ -55,7 +55,8 @@ -LOG StartNodeTestRecursiveRel -STATEMENT MATCH (:person { fName: "Alice" })-[r:knows *1..2]->(friend:person) RETURN START_NODE(r) ---- error -Binder exception: r has data type RECURSIVE_REL but REL was expected. +Binder exception: Cannot match a built-in function for given function START_NODE(RECURSIVE_REL). Supported inputs are +(REL) -CASE FunctionEndNode @@ -111,4 +112,5 @@ Binder exception: r has data type RECURSIVE_REL but REL was expected. -LOG EndNodeTestRecursiveRel -STATEMENT MATCH (:person { fName: "Alice" })-[r:knows *1..2]->(friend:person) RETURN END_NODE(r) ---- error -Binder exception: r has data type RECURSIVE_REL but REL was expected. +Binder exception: Cannot match a built-in function for given function END_NODE(RECURSIVE_REL). Supported inputs are +(REL)