Skip to content

Commit

Permalink
Remove recursive extend binding
Browse files Browse the repository at this point in the history
  • Loading branch information
andyfengHKU committed Feb 13, 2025
1 parent b3e085f commit 669570e
Show file tree
Hide file tree
Showing 9 changed files with 237 additions and 431 deletions.
8 changes: 4 additions & 4 deletions src/function/gds/all_shortest_paths.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -216,10 +216,10 @@ class AllSPPathsEdgeCompute : public SPEdgeCompute {
* is returned for each destination. If paths are not returned, multiplicities indicate the
* number of paths to each destination.
*/
class AllSPDestinationsAlgorithm final : public SPAlgorithm {
class AllSPDestinationsAlgorithm final : public RJAlgorithm {

Check warning on line 219 in src/function/gds/all_shortest_paths.cpp

View check run for this annotation

Codecov / codecov/patch

src/function/gds/all_shortest_paths.cpp#L219

Added line #L219 was not covered by tests
public:
AllSPDestinationsAlgorithm() = default;
AllSPDestinationsAlgorithm(const AllSPDestinationsAlgorithm& other) : SPAlgorithm{other} {}
AllSPDestinationsAlgorithm(const AllSPDestinationsAlgorithm& other) : RJAlgorithm{other} {}

expression_vector getResultColumns(const function::GDSBindInput& /*bindInput*/) const override {
auto columns = getBaseResultColumns();
Expand Down Expand Up @@ -250,10 +250,10 @@ class AllSPDestinationsAlgorithm final : public SPAlgorithm {
}
};

class AllSPPathsAlgorithm final : public SPAlgorithm {
class AllSPPathsAlgorithm final : public RJAlgorithm {

Check warning on line 253 in src/function/gds/all_shortest_paths.cpp

View check run for this annotation

Codecov / codecov/patch

src/function/gds/all_shortest_paths.cpp#L253

Added line #L253 was not covered by tests
public:
AllSPPathsAlgorithm() = default;
AllSPPathsAlgorithm(const AllSPPathsAlgorithm& other) : SPAlgorithm{other} {}
AllSPPathsAlgorithm(const AllSPPathsAlgorithm& other) : RJAlgorithm{other} {}

expression_vector getResultColumns(const function::GDSBindInput& /*bindInput*/) const override {
auto columns = getBaseResultColumns();
Expand Down
70 changes: 6 additions & 64 deletions src/function/gds/rec_joins.cpp
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
#include "function/gds/rec_joins.h"

#include "binder/binder.h"
#include "binder/expression/expression_util.h"
#include "binder/expression/property_expression.h"
#include "common/enums/extend_direction_util.h"
#include "common/exception/binder.h"
#include "common/exception/interrupt.h"
#include "common/exception/runtime.h"
#include "common/task_system/progress_bar.h"
#include "function/gds/gds.h"
#include "function/gds/gds_utils.h"
Expand Down Expand Up @@ -47,29 +45,13 @@ PathsOutputWriterInfo RJBindData::getPathWriterInfo() const {
return info;
}

void RJAlgorithm::setToNoPath() {
bindData->ptrCast<RJBindData>()->writePath = false;
void RJAlgorithm::bind(const kuzu::function::GDSBindInput&, main::ClientContext&) {
throw common::BinderException("Recursive join should not be triggered through function calls. "
"Try cypher patter ()-[*]->() instead.");
}

void RJAlgorithm::validateLowerUpperBound(int64_t lowerBound, int64_t upperBound) {
if (lowerBound < 0 || upperBound < 0) {
throw RuntimeException(
stringFormat("Lower and upper bound lengths of recursive join operations need to be "
"non-negative. Given lower bound is: {} and upper bound is: {}.",
lowerBound, upperBound));
}
if (lowerBound > upperBound) {
throw RuntimeException(
stringFormat("Lower bound length of recursive join operations need to be less than or "
"equal to upper bound. Given lower bound is: {} and upper bound is: {}.",
lowerBound, upperBound));
}
if (upperBound >= RJBindData::DEFAULT_MAXIMUM_ALLOWED_UPPER_BOUND) {
throw RuntimeException(
stringFormat("Recursive join operations only works for non-positive upper bound "
"iterations that are up to {}. Given upper bound is: {}.",
RJBindData::DEFAULT_MAXIMUM_ALLOWED_UPPER_BOUND, upperBound));
}
void RJAlgorithm::setToNoPath() {
bindData->ptrCast<RJBindData>()->writePath = false;
}

binder::expression_vector RJAlgorithm::getResultColumnsNoPath() {
Expand All @@ -95,46 +77,6 @@ expression_vector RJAlgorithm::getBaseResultColumns() const {
return columns;
}

void RJAlgorithm::bindColumnExpressions(binder::Binder* binder) const {
auto rjBindData = bindData->ptrCast<RJBindData>();
if (rjBindData->extendDirection == common::ExtendDirection::BOTH) {
rjBindData->directionExpr =
binder->createVariable(DIRECTION_COLUMN_NAME, LogicalType::LIST(LogicalType::BOOL()));
}
rjBindData->lengthExpr = binder->createVariable(LENGTH_COLUMN_NAME, LogicalType::UINT16());
rjBindData->pathNodeIDsExpr = binder->createVariable(PATH_NODE_IDS_COLUMN_NAME,
LogicalType::LIST(LogicalType::INTERNAL_ID()));
rjBindData->pathEdgeIDsExpr = binder->createVariable(PATH_EDGE_IDS_COLUMN_NAME,
LogicalType::LIST(LogicalType::INTERNAL_ID()));
}

static void validateSPUpperBound(int64_t upperBound) {
if (upperBound == 0) {
throw RuntimeException(stringFormat("Shortest path operations only works for positive "
"upper bound iterations. Given upper bound is: {}.",
upperBound));
}
}

void SPAlgorithm::bind(const GDSBindInput& input, main::ClientContext& context) {
KU_ASSERT(input.getNumParams() == 4);
auto graphName = ExpressionUtil::getLiteralValue<std::string>(*input.getParam(0));
auto graphEntry = bindGraphEntry(context, graphName);
auto nodeOutput = bindNodeOutput(input, graphEntry.nodeEntries);
auto rjBindData = std::make_unique<RJBindData>(std::move(graphEntry), nodeOutput);
rjBindData->nodeInput = input.getParam(1);
rjBindData->lowerBound = 1;
auto upperBound = ExpressionUtil::getLiteralValue<int64_t>(*input.getParam(2));
validateSPUpperBound(upperBound);
validateLowerUpperBound(rjBindData->lowerBound, upperBound);
rjBindData->upperBound = upperBound;
rjBindData->semantic = PathSemantic::WALK;
rjBindData->extendDirection = ExtendDirectionUtil::fromString(
ExpressionUtil::getLiteralValue<std::string>(*input.getParam(3)));
bindData = std::move(rjBindData);
bindColumnExpressions(input.binder);
}

// All recursive join computation have the same vertex compute. This vertex compute writes
// result (could be dst, length or path) from a dst node ID to given source node ID.
class RJVertexCompute : public VertexCompute {
Expand Down
8 changes: 4 additions & 4 deletions src/function/gds/single_shortest_paths.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,11 @@ class SingleSPPathsEdgeCompute : public SPEdgeCompute {
* multiplicities of each destination is ignored (e.g., if there are 3 paths to a destination d,
* d is returned only once).
*/
class SingleSPDestinationsAlgorithm : public SPAlgorithm {
class SingleSPDestinationsAlgorithm : public RJAlgorithm {
public:
SingleSPDestinationsAlgorithm() = default;
SingleSPDestinationsAlgorithm(const SingleSPDestinationsAlgorithm& other)
: SPAlgorithm{other} {}
: RJAlgorithm{other} {}

expression_vector getResultColumns(const function::GDSBindInput& /*bindInput*/) const override {
auto columns = getBaseResultColumns();
Expand All @@ -149,10 +149,10 @@ class SingleSPDestinationsAlgorithm : public SPAlgorithm {
}
};

class SingleSPPathsAlgorithm : public SPAlgorithm {
class SingleSPPathsAlgorithm : public RJAlgorithm {
public:
SingleSPPathsAlgorithm() = default;
SingleSPPathsAlgorithm(const SingleSPPathsAlgorithm& other) : SPAlgorithm{other} {}
SingleSPPathsAlgorithm(const SingleSPPathsAlgorithm& other) : RJAlgorithm{other} {}

expression_vector getResultColumns(const function::GDSBindInput& /*bindInput*/) const override {
auto columns = getBaseResultColumns();
Expand Down
26 changes: 0 additions & 26 deletions src/function/gds/variable_length_path.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
#include <vector>

#include "binder/expression/expression_util.h"
#include "common/enums/extend_direction_util.h"
#include "common/types/types.h"
#include "function/gds/auxiliary_state/path_auxiliary_state.h"
#include "function/gds/gds_function_collection.h"
Expand Down Expand Up @@ -84,30 +82,6 @@ class VarLenJoinsAlgorithm final : public RJAlgorithm {
VarLenJoinsAlgorithm() = default;
VarLenJoinsAlgorithm(const VarLenJoinsAlgorithm& other) : RJAlgorithm(other) {}

// Inputs are: graph, srcNode, lowerBound, upperBound, direction
std::vector<LogicalTypeID> getParameterTypeIDs() const override {
return {LogicalTypeID::ANY, LogicalTypeID::NODE, LogicalTypeID::INT64, LogicalTypeID::INT64,
LogicalTypeID::STRING};
}

void bind(const GDSBindInput& input, main::ClientContext& context) override {
auto graphName = ExpressionUtil::getLiteralValue<std::string>(*input.getParam(0));
auto graphEntry = bindGraphEntry(context, graphName);
auto nodeOutput = bindNodeOutput(input, graphEntry.nodeEntries);
auto rjBindData = std::make_unique<RJBindData>(std::move(graphEntry), nodeOutput);
rjBindData->nodeInput = input.getParam(1);
auto lowerBound = ExpressionUtil::getLiteralValue<int64_t>(*input.getParam(2));
auto upperBound = ExpressionUtil::getLiteralValue<int64_t>(*input.getParam(3));
validateLowerUpperBound(lowerBound, upperBound);
rjBindData->lowerBound = lowerBound;
rjBindData->upperBound = upperBound;
rjBindData->semantic = PathSemantic::WALK;
rjBindData->extendDirection = ExtendDirectionUtil::fromString(
ExpressionUtil::getLiteralValue<std::string>(*input.getParam(4)));
bindData = std::move(rjBindData);
bindColumnExpressions(input.binder);
}

binder::expression_vector getResultColumns(
const function::GDSBindInput& /*bindInput*/) const override {
auto columns = getBaseResultColumns();
Expand Down
4 changes: 2 additions & 2 deletions src/function/gds/weighted_shortest_paths.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,11 @@ class DestinationsOutputWriter : public RJOutputWriter {
std::unique_ptr<ValueVector> costVector;
};

class WeightedSPDestinationsAlgorithm : public SPAlgorithm {
class WeightedSPDestinationsAlgorithm : public RJAlgorithm {
public:
WeightedSPDestinationsAlgorithm() = default;
WeightedSPDestinationsAlgorithm(const WeightedSPDestinationsAlgorithm& other)
: SPAlgorithm{other} {}
: RJAlgorithm{other} {}

binder::expression_vector getResultColumns(
const function::GDSBindInput& /*bindInput*/) const override {
Expand Down
19 changes: 2 additions & 17 deletions src/include/function/gds/rec_joins.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ class RJAlgorithm : public GDSAlgorithm {
RJAlgorithm() = default;
RJAlgorithm(const RJAlgorithm& other) : GDSAlgorithm{other} {}

void bind(const kuzu::function::GDSBindInput& input, main::ClientContext& context) override;

void exec(processor::ExecutionContext* context) override;

virtual RJCompState getRJCompState(processor::ExecutionContext* context,
Expand All @@ -77,27 +79,10 @@ class RJAlgorithm : public GDSAlgorithm {
binder::expression_vector getResultColumnsNoPath();

protected:
void validateLowerUpperBound(int64_t lowerBound, int64_t upperBound);

binder::expression_vector getBaseResultColumns() const;
void bindColumnExpressions(binder::Binder* binder) const;

std::unique_ptr<BFSGraph> getBFSGraph(processor::ExecutionContext* context);
};

class SPAlgorithm : public RJAlgorithm {
public:
SPAlgorithm() = default;
SPAlgorithm(const SPAlgorithm& other) : RJAlgorithm{other} {}

// Inputs are graph, srcNode, upperBound, direction
std::vector<common::LogicalTypeID> getParameterTypeIDs() const override {
return {common::LogicalTypeID::ANY, common::LogicalTypeID::NODE,
common::LogicalTypeID::INT64, common::LogicalTypeID::STRING};
}

void bind(const GDSBindInput& input, main::ClientContext&) override;
};

} // namespace function
} // namespace kuzu
35 changes: 9 additions & 26 deletions test/test_files/function/gds/basic.test
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,7 @@ Runtime exception: Project graph PK already exists.
Binder exception: Expect catalog entry type NODE_TABLE_ENTRY but got REL_TABLE_ENTRY.
-STATEMENT CALL create_project_graph('PKWO', ['person', 'organisation'], ['knows', 'workAt'])
---- ok
-STATEMENT MATCH (a:person) WHERE a.ID < 6
CALL VAR_LEN_JOINS('PK', a, 1, 2, "FWD")
RETURN a.fName, COUNT(*);
-STATEMENT MATCH (a:person)-[:knows*1..2]->(b:person) WHERE a.ID < 6 RETURN a.fName, COUNT(*);
---- 4
Alice|12
Bob|12
Expand All @@ -28,42 +26,27 @@ Runtime exception: Project graph dummy does not exists.
---- ok
-STATEMENT CALL create_project_graph('PK', ['person'], ['knows'])
---- ok
-STATEMENT MATCH (a:person) WHERE a.ID < 6
CALL VAR_LEN_JOINS('PK', a, 1, 2, "BWD")
RETURN a.fName, COUNT(*);
-STATEMENT CALL VAR_LEN_JOINS() RETURN *
---- error
Binder exception: Recursive join should not be triggered through function calls. Try cypher patter ()-[*]->() instead.
-STATEMENT MATCH (a:person)<-[:knows*1..2]-(b:person) WHERE a.ID < 6 RETURN a.fName, COUNT(*);
---- 4
Alice|12
Bob|12
Carol|12
Dan|12
-STATEMENT MATCH (a:person) WHERE a.ID < 6
CALL VAR_LEN_JOINS('PK', a, 1, 2, "BOTH")
RETURN a.fName, COUNT(*);
-STATEMENT MATCH (a:person)-[:knows*1..2]-(b:person) WHERE a.ID < 6 RETURN a.fName, COUNT(*);
---- 4
Alice|42
Bob|42
Carol|42
Dan|42
-STATEMENT MATCH (a:person) WHERE a.ID = 0
CALL SINGLE_SP_DESTINATIONS('PK', a, 2, "FWD")
RETURN a.fName, node.name, length;
---- error
Binder exception: Cannot find property name for node.
-STATEMENT MATCH (a:person) WHERE a.ID = 0
CALL SINGLE_SP_DESTINATIONS('PK', a, 2, "X")
RETURN a.fName, node.name, length;
---- error
Runtime exception: Cannot parse X as ExtendDirection.
-STATEMENT MATCH (a:person) WHERE a.ID = 0
CALL SINGLE_SP_DESTINATIONS('PK', a, 2, "FWD")
RETURN a.fName, node.fName, length;
-STATEMENT MATCH (a:person)-[e:knows* SHORTEST 1..2]->(b:person) WHERE a.ID = 0 RETURN a.fName, b.fName, length(e);
---- 3
Alice|Bob|1
Alice|Carol|1
Alice|Dan|1
-STATEMENT MATCH (a:person) WHERE a.ID = 0
CALL SINGLE_SP_DESTINATIONS('PKWO', a, 2, "FWD")
RETURN a.fName, node.fName, node.name, length;
-STATEMENT MATCH (a:person)-[e:knows|:workAt* SHORTEST 1..2]->(b:person:organisation) WHERE a.ID = 0 RETURN a.fName, b.fName, b.name, length(e);
---- 5
Alice|Bob||1
Alice|Carol||1
Expand Down Expand Up @@ -144,4 +127,4 @@ Hubert Blaine Wolfeschlegelsteinhausenbergerdorff|0
[Alice,Dan,Alice]|[2021-06-30,2021-06-30]|[0:0,0:3]|[0:3,0:0]|Alice|Alice
[Alice,Dan,Bob]|[2021-06-30,1950-05-14]|[0:0,0:3]|[0:3,0:1]|Alice|Bob
[Alice,Dan,Carol]|[2021-06-30,2000-01-01]|[0:0,0:3]|[0:3,0:2]|Alice|Carol
[Alice,Dan]|[2021-06-30]|[0:0]|[0:3]|Alice|Dan
[Alice,Dan]|[2021-06-30]|[0:0]|[0:3]|Alice|Dan
Loading

0 comments on commit 669570e

Please sign in to comment.