Skip to content

Commit

Permalink
Move hard code function binding to rewrite function framework (#4813)
Browse files Browse the repository at this point in the history
  • Loading branch information
andyfengHKU authored Feb 1, 2025
1 parent f7eb350 commit 69c6456
Show file tree
Hide file tree
Showing 20 changed files with 261 additions and 204 deletions.
8 changes: 6 additions & 2 deletions src/binder/bind/bind_graph_pattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
#include "common/string_format.h"
#include "common/utils.h"
#include "function/cast/functions/cast_from_string_functions.h"
#include "function/rewrite_function.h"
#include "function/schema/vector_node_rel_functions.h"
#include "main/client_context.h"

using namespace kuzu::common;
Expand Down Expand Up @@ -281,7 +283,8 @@ std::shared_ptr<RelExpression> Binder::createNonRecursiveQueryRel(const std::str
fields.emplace_back(InternalKeyword::SRC, LogicalType::INTERNAL_ID());
fields.emplace_back(InternalKeyword::DST, LogicalType::INTERNAL_ID());
// Bind internal expressions.
queryRel->setLabelExpression(expressionBinder.bindLabelFunction(*queryRel));
auto input = function::RewriteFunctionBindInput(clientContext, &expressionBinder, {queryRel});
queryRel->setLabelExpression(function::LabelFunction::rewriteFunc(input));
fields.emplace_back(InternalKeyword::LABEL,
queryRel->getLabelExpression()->getDataType().copy());
// Bind properties.
Expand Down Expand Up @@ -576,7 +579,8 @@ std::shared_ptr<NodeExpression> Binder::createQueryNode(const std::string& parse
// Bind internal expressions
queryNode->setInternalID(
PropertyExpression::construct(LogicalType::INTERNAL_ID(), InternalKeyword::ID, *queryNode));
queryNode->setLabelExpression(expressionBinder.bindLabelFunction(*queryNode));
auto input = function::RewriteFunctionBindInput(clientContext, &expressionBinder, {queryNode});
queryNode->setLabelExpression(function::LabelFunction::rewriteFunc(input));
fieldNames.emplace_back(InternalKeyword::ID);
fieldNames.emplace_back(InternalKeyword::LABEL);
fieldTypes.push_back(queryNode->getInternalID()->getDataType().copy());
Expand Down
124 changes: 7 additions & 117 deletions src/binder/bind_expression/bind_function_expression.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#include "binder/binder.h"
#include "binder/expression/aggregate_function_expression.h"
#include "binder/expression/expression_util.h"
#include "binder/expression/scalar_function_expression.h"
#include "binder/expression_binder.h"
#include "catalog/catalog.h"
Expand All @@ -10,8 +9,6 @@
#include "function/cast/vector_cast_functions.h"
#include "function/rewrite_function.h"
#include "function/scalar_macro_function.h"
#include "function/schema/vector_label_functions.h"
#include "function/schema/vector_node_rel_functions.h"
#include "main/client_context.h"
#include "parser/expression/parsed_function_expression.h"
#include "parser/parsed_expression_visitor.h"
Expand All @@ -27,10 +24,6 @@ namespace binder {
std::shared_ptr<Expression> ExpressionBinder::bindFunctionExpression(const ParsedExpression& expr) {
auto funcExpr = expr.constPtrCast<ParsedFunctionExpression>();
auto functionName = funcExpr->getNormalizedFunctionName();
auto result = rewriteFunctionExpression(expr, functionName);
if (result != nullptr) {
return result;
}
auto entry = context->getCatalog()->getFunctionEntry(context->getTransaction(), functionName);
switch (entry->getType()) {
case CatalogEntryType::SCALAR_FUNCTION_ENTRY:
Expand Down Expand Up @@ -89,8 +82,9 @@ std::shared_ptr<Expression> ExpressionBinder::bindScalarFunctionExpression(
}
expression_vector childrenAfterCast;
std::unique_ptr<function::FunctionBindData> bindData;
auto bindInput = ScalarBindFuncInput{children, function.get(), context};
if (functionName == CastAnyFunction::name) {
bindData = function->bindFunc({children, function.get(), context});
bindData = function->bindFunc(bindInput);
if (bindData == nullptr) { // No need to cast.
// TODO(Xiyang): We should return a deep copy otherwise the same expression might
// appear in the final projection list repeatedly.
Expand All @@ -104,7 +98,7 @@ std::shared_ptr<Expression> ExpressionBinder::bindScalarFunctionExpression(
childrenAfterCast.push_back(std::move(childAfterCast));
} else {
if (function->bindFunc) {
bindData = function->bindFunc({children, function.get(), context});
bindData = function->bindFunc(bindInput);
} else {
bindData = std::make_unique<FunctionBindData>(LogicalType(function->returnTypeID));
}
Expand Down Expand Up @@ -142,7 +136,8 @@ std::shared_ptr<Expression> ExpressionBinder::bindRewriteFunctionExpression(
entry->ptrCast<FunctionCatalogEntry>());
auto function = match->constPtrCast<RewriteFunction>();
KU_ASSERT(function->rewriteFunc != nullptr);
return function->rewriteFunc(children, this);
auto input = RewriteFunctionBindInput(context, this, children);
return function->rewriteFunc(input);
}

std::shared_ptr<Expression> ExpressionBinder::bindAggregateFunctionExpression(
Expand Down Expand Up @@ -173,7 +168,8 @@ std::shared_ptr<Expression> ExpressionBinder::bindAggregateFunctionExpression(
}
std::unique_ptr<FunctionBindData> bindData;
if (function.bindFunc) {
bindData = function.bindFunc({children, &function, context});
auto bindInput = ScalarBindFuncInput{children, &function, context};
bindData = function.bindFunc(bindInput);
} else {
bindData = std::make_unique<function::FunctionBindData>(LogicalType(function.returnTypeID));
}
Expand Down Expand Up @@ -207,111 +203,5 @@ std::shared_ptr<Expression> ExpressionBinder::bindMacroExpression(
return bindExpression(*macroParameterReplacer->visit(std::move(macroExpr)));
}

// Function rewriting happens when we need to expose internal property access through function so
// that it becomes read-only or the function involves catalog information. Currently we write
// Before | After
// LABEL(a) | LIST_EXTRACT(offset(a), [table names from catalog])
// STARTNODE(a) | a._src
// ENDNODE(a) | a._dst
std::shared_ptr<Expression> ExpressionBinder::rewriteFunctionExpression(
const parser::ParsedExpression& parsedExpression, const std::string& functionName) {
if (functionName == LabelFunction::name) {
auto child = bindExpression(*parsedExpression.getChild(0));
ExpressionUtil::validateDataType(*child,
std::vector<LogicalTypeID>{LogicalTypeID::NODE, LogicalTypeID::REL});
return bindLabelFunction(*child);
} else if (functionName == StartNodeFunction::name) {
auto child = bindExpression(*parsedExpression.getChild(0));
ExpressionUtil::validateDataType(*child, LogicalTypeID::REL);
return bindStartNodeExpression(*child);
} else if (functionName == EndNodeFunction::name) {
auto child = bindExpression(*parsedExpression.getChild(0));
ExpressionUtil::validateDataType(*child, LogicalTypeID::REL);
return bindEndNodeExpression(*child);
}
return nullptr;
}

std::shared_ptr<Expression> ExpressionBinder::bindStartNodeExpression(
const Expression& expression) {
return expression.constCast<RelExpression>().getSrcNode();
}

std::shared_ptr<Expression> ExpressionBinder::bindEndNodeExpression(const Expression& expression) {
return expression.constCast<RelExpression>().getDstNode();
}

static std::vector<std::unique_ptr<Value>> populateLabelValues(const main::ClientContext* context,
const std::vector<TableCatalogEntry*>& entries) {
std::unordered_map<table_id_t, std::string> map;
table_id_t maxTableID = 0;
for (auto& entry : entries) {
map.insert({entry->getTableID(),
entry->getLabel(context->getCatalog(), context->getTransaction())});
if (entry->getTableID() > maxTableID) {
maxTableID = entry->getTableID();
}
}
std::vector<std::unique_ptr<Value>> labels;
labels.resize(maxTableID + 1);
for (auto i = 0u; i < labels.size(); ++i) {
if (map.contains(i)) {
labels[i] = std::make_unique<Value>(LogicalType::STRING(), map.at(i));
} else {
// TODO(Xiyang/Guodong): change to null literal once we support null in LIST type.
labels[i] = std::make_unique<Value>(LogicalType::STRING(), std::string(""));
}
}
return labels;
}

std::shared_ptr<Expression> ExpressionBinder::bindLabelFunction(
const Expression& expression) const {
auto listType = LogicalType::LIST(LogicalType::STRING());
auto catalog = context->getCatalog();
auto transaction = context->getTransaction();
expression_vector children;
switch (expression.getDataType().getLogicalTypeID()) {
case LogicalTypeID::NODE: {
auto& node = expression.constCast<NodeExpression>();
if (node.isEmpty()) {
return createLiteralExpression("");
}
if (!node.isMultiLabeled()) {
return createLiteralExpression(Value(LogicalType::STRING(),
node.getSingleEntry()->getLabel(catalog, transaction)));
}
children.push_back(node.getInternalID());
auto labelsValue =
Value(std::move(listType), populateLabelValues(context, node.getEntries()));
children.push_back(createLiteralExpression(labelsValue));
} break;
case LogicalTypeID::REL: {
auto& rel = expression.constCast<RelExpression>();
if (rel.isEmpty()) {
return createLiteralExpression("");
}
if (!rel.isMultiLabeled()) {
return createLiteralExpression(
Value(LogicalType::STRING(), rel.getSingleEntry()->getLabel(catalog, transaction)));
}
children.push_back(rel.getInternalIDProperty());
auto labelsValue =
Value(std::move(listType), populateLabelValues(context, rel.getEntries()));
children.push_back(createLiteralExpression(labelsValue));
} break;
default:
KU_UNREACHABLE;
}
auto function = std::make_unique<ScalarFunction>(LabelFunction::name,
std::vector<LogicalTypeID>{LogicalTypeID::STRING, LogicalTypeID::INT64},
LogicalTypeID::STRING, LabelFunction::execFunction);
auto bindData = std::make_unique<function::FunctionBindData>(LogicalType::STRING());
auto uniqueExpressionName =
ScalarFunctionExpression::getUniqueName(LabelFunction::name, children);
return std::make_shared<ScalarFunctionExpression>(ExpressionType::FUNCTION, std::move(function),
std::move(bindData), std::move(children), uniqueExpressionName);
}

} // namespace binder
} // namespace kuzu
1 change: 0 additions & 1 deletion src/expression_evaluator/function_evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ bool FunctionExpressionEvaluator::select(SelectionVector& selVector) {
// the correctness, because when all children are flat the check is done on tuple.
auto pos = resultVector->state->getSelVector()[0];
return resultVector->isNull(pos) ? 0 : resultVector->getValue<bool>(pos);
;
}
}
return function->selectFunc(parameters, selVector);
Expand Down
2 changes: 2 additions & 0 deletions src/function/function_collection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,8 @@ FunctionCollection* FunctionCollection::getFunctions() {

// Node/rel functions
SCALAR_FUNCTION(OffsetFunction), REWRITE_FUNCTION(IDFunction),
REWRITE_FUNCTION(StartNodeFunction), REWRITE_FUNCTION(EndNodeFunction),
REWRITE_FUNCTION(LabelFunction),

// Path functions
SCALAR_FUNCTION(NodesFunction), SCALAR_FUNCTION(RelsFunction),
Expand Down
8 changes: 4 additions & 4 deletions src/function/path/length_function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ using namespace kuzu::common;
namespace kuzu {
namespace function {

static std::shared_ptr<Expression> rewriteFunc(const expression_vector& params,
ExpressionBinder* binder) {
KU_ASSERT(params.size() == 1);
auto param = params[0].get();
static std::shared_ptr<Expression> rewriteFunc(const RewriteFunctionBindInput& input) {
KU_ASSERT(input.arguments.size() == 1);
auto param = input.arguments[0].get();
auto binder = input.expressionBinder;
if (param->expressionType == ExpressionType::PATH) {
int64_t numRels = 0u;
std::vector<const RelExpression*> recursiveRels;
Expand Down
4 changes: 3 additions & 1 deletion src/function/pattern/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
add_library(kuzu_function_pattern
OBJECT
id_function.cpp)
id_function.cpp
label_function.cpp
start_end_node_function.cpp)

set(ALL_OBJECT_FILES
${ALL_OBJECT_FILES} $<TARGET_OBJECTS:kuzu_function_pattern>
Expand Down
15 changes: 6 additions & 9 deletions src/function/pattern/id_function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
#include "binder/expression/node_expression.h"
#include "binder/expression/rel_expression.h"
#include "binder/expression_binder.h"
#include "common/types/value/value.h"
#include "function/rewrite_function.h"
#include "function/schema/vector_node_rel_functions.h"
#include "function/struct/vector_struct_functions.h"
Expand All @@ -13,10 +12,9 @@ using namespace kuzu::binder;
namespace kuzu {
namespace function {

static std::shared_ptr<Expression> rewriteFunc(const expression_vector& params,
ExpressionBinder* binder) {
KU_ASSERT(params.size() == 1);
auto param = params[0].get();
static std::shared_ptr<Expression> rewriteFunc(const RewriteFunctionBindInput& input) {
KU_ASSERT(input.arguments.size() == 1);
auto param = input.arguments[0].get();
if (ExpressionUtil::isNodePattern(*param)) {
auto node = param->constPtrCast<NodeExpression>();
return node->getInternalID();
Expand All @@ -26,10 +24,9 @@ static std::shared_ptr<Expression> rewriteFunc(const expression_vector& params,
return rel->getPropertyExpression(InternalKeyword::ID);
}
// Bind as struct_extract(param, "_id")
auto keyExpr =
binder->createLiteralExpression(Value(LogicalType::STRING(), InternalKeyword::ID));
auto newParams = expression_vector{params[0], keyExpr};
return binder->bindScalarFunctionExpression(newParams, StructExtractFunctions::name);
auto extractKey = input.expressionBinder->createLiteralExpression(InternalKeyword::ID);
return input.expressionBinder->bindScalarFunctionExpression({input.arguments[0], extractKey},
StructExtractFunctions::name);
}

function_set IDFunction::getFunctionSet() {
Expand Down
Loading

0 comments on commit 69c6456

Please sign in to comment.