Skip to content

Commit

Permalink
Fix HNSW index shrink; add setting enable_internal_catalog (#4939)
Browse files Browse the repository at this point in the history
  • Loading branch information
ray6080 authored Feb 20, 2025
1 parent 5d4664a commit c031db9
Show file tree
Hide file tree
Showing 11 changed files with 49 additions and 11 deletions.
2 changes: 1 addition & 1 deletion src/binder/bind/bind_export_database.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ static std::vector<ExportedTableData> getExportInfo(const Catalog& catalog,
main::ClientContext* context, Binder* binder) {
auto transaction = context->getTransaction();
std::vector<ExportedTableData> exportData;
for (auto tableEntry : catalog.getTableEntries(transaction)) {
for (auto tableEntry : catalog.getTableEntries(transaction, false /*useInternal*/)) {
ExportedTableData tableData;
if (binder->bindExportTableData(tableData, *tableEntry, catalog, transaction)) {
exportData.push_back(std::move(tableData));
Expand Down
8 changes: 7 additions & 1 deletion src/catalog/catalog.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,11 +146,17 @@ std::vector<RelTableCatalogEntry*> Catalog::getRelTableEntries(const Transaction
return result;
}

std::vector<TableCatalogEntry*> Catalog::getTableEntries(const Transaction* transaction) const {
std::vector<TableCatalogEntry*> Catalog::getTableEntries(const Transaction* transaction,
bool useInternal) const {
std::vector<TableCatalogEntry*> result;
for (auto& [_, entry] : tables->getEntries(transaction)) {
result.push_back(entry->ptrCast<TableCatalogEntry>());
}
if (useInternal) {
for (auto& [_, entry] : internalTables->getEntries(transaction)) {
result.push_back(entry->ptrCast<RelTableCatalogEntry>());
}
}
return result;
}

Expand Down
7 changes: 4 additions & 3 deletions src/function/table/show_tables.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ static std::unique_ptr<TableFuncBindData> bindFunc(const main::ClientContext* co
auto transaction = context->getTransaction();
if (!context->hasDefaultDatabase()) {
auto catalog = context->getCatalog();
for (auto& entry : catalog->getTableEntries(transaction)) {
for (auto& entry :
catalog->getTableEntries(transaction, context->useInternalCatalogEntry())) {
if (entry->getType() == CatalogEntryType::REL_TABLE_ENTRY &&
entry->constCast<RelTableCatalogEntry>().hasParentRelGroup(catalog, transaction)) {
continue;
Expand All @@ -95,8 +96,8 @@ static std::unique_ptr<TableFuncBindData> bindFunc(const main::ClientContext* co
for (auto attachedDatabase : databaseManager->getAttachedDatabases()) {
auto databaseName = attachedDatabase->getDBName();
auto databaseType = attachedDatabase->getDBType();
for (auto& entry :
attachedDatabase->getCatalog()->getTableEntries(context->getTransaction())) {
for (auto& entry : attachedDatabase->getCatalog()->getTableEntries(
context->getTransaction(), context->useInternalCatalogEntry())) {
auto tableInfo = TableInfo{entry->getName(), entry->getTableID(),
TableTypeUtils::toString(entry->getTableType()),
stringFormat("{}({})", databaseName, databaseType), entry->getComment()};
Expand Down
4 changes: 2 additions & 2 deletions src/include/catalog/catalog.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ class KUZU_API Catalog {
std::vector<RelTableCatalogEntry*> getRelTableEntries(
const transaction::Transaction* transaction, bool useInternal = true) const;
// Get all table entries.
std::vector<TableCatalogEntry*> getTableEntries(
const transaction::Transaction* transaction) const;
std::vector<TableCatalogEntry*> getTableEntries(const transaction::Transaction* transaction,
bool useInternal = true) const;

// Create table catalog entry.
CatalogEntry* createTableEntry(transaction::Transaction* transaction,
Expand Down
2 changes: 2 additions & 0 deletions src/include/main/client_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ struct ClientConfigDefault {
static constexpr bool DISABLE_MAP_KEY_CHECK = true;
static constexpr uint64_t WARNING_LIMIT = 8 * 1024;
static constexpr bool ENABLE_PLAN_OPTIMIZER = true;
static constexpr bool ENABLE_INTERNAL_CATALOG = false;
};

struct ClientConfig {
Expand Down Expand Up @@ -53,6 +54,7 @@ struct ClientConfig {
uint64_t warningLimit = ClientConfigDefault::WARNING_LIMIT;
bool disableMapKeyCheck = ClientConfigDefault::DISABLE_MAP_KEY_CHECK;
bool enablePlanOptimizer = ClientConfigDefault::ENABLE_PLAN_OPTIMIZER;
bool enableInternalCatalog = ClientConfigDefault::ENABLE_INTERNAL_CATALOG;
};

} // namespace main
Expand Down
4 changes: 3 additions & 1 deletion src/include/main/client_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,9 @@ class KUZU_API ClientContext {
void setUseInternalCatalogEntry(bool useInternalCatalogEntry) {
this->useInternalCatalogEntry_ = useInternalCatalogEntry;
}
bool useInternalCatalogEntry() const { return useInternalCatalogEntry_; }
bool useInternalCatalogEntry() const {
return clientConfig.enableInternalCatalog ? true : useInternalCatalogEntry_;
}

void addScalarFunction(std::string name, function::function_set definitions);
void removeScalarFunction(const std::string& name);
Expand Down
12 changes: 12 additions & 0 deletions src/include/main/settings.h
Original file line number Diff line number Diff line change
Expand Up @@ -241,5 +241,17 @@ struct EnableOptimizerSetting {
}
};

struct EnableInternalCatalogSetting {
static constexpr auto name = "enable_internal_catalog";
static constexpr auto inputType = common::LogicalTypeID::BOOL;
static void setContext(ClientContext* context, const common::Value& parameter) {
parameter.validateType(inputType);
context->getClientConfigUnsafe()->enableInternalCatalog = parameter.getValue<bool>();
}
static common::Value getSetting(const ClientContext* context) {
return common::Value::createValue(context->getClientConfig()->enableInternalCatalog);
}
};

} // namespace main
} // namespace kuzu
3 changes: 2 additions & 1 deletion src/main/db_config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ static ConfigurationOption options[] = { // NOLINT(cert-err58-cpp):
GET_CONFIGURATION(RecursivePatternFactorSetting), GET_CONFIGURATION(EnableMVCCSetting),
GET_CONFIGURATION(CheckpointThresholdSetting), GET_CONFIGURATION(AutoCheckpointSetting),
GET_CONFIGURATION(ForceCheckpointClosingDBSetting), GET_CONFIGURATION(SpillToDiskSetting),
GET_CONFIGURATION(EnableGDSSetting), GET_CONFIGURATION(EnableOptimizerSetting)};
GET_CONFIGURATION(EnableGDSSetting), GET_CONFIGURATION(EnableOptimizerSetting),
GET_CONFIGURATION(EnableInternalCatalogSetting)};

DBConfig::DBConfig(const SystemConfig& systemConfig)
: bufferPoolSize{systemConfig.bufferPoolSize}, maxNumThreads{systemConfig.maxNumThreads},
Expand Down
2 changes: 1 addition & 1 deletion src/storage/index/hnsw_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ void InMemHNSWLayer::shrink(transaction::Transaction* transaction) {
for (auto i = 0u; i < info.numNodes; i++) {
const auto numNbrs = graph->getCSRLength(i);
if (numNbrs <= info.maxDegree) {
return;
continue;
}
shrinkForNode(transaction, info, graph.get(), i, numNbrs);
}
Expand Down
13 changes: 13 additions & 0 deletions test/test_files/function/hnsw/small.test
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,19 @@
-STATEMENT CALL show_tables() return *;
---- 1
0|embeddings|NODE|local(kuzu)|
-STATEMENT CALL ENABLE_INTERNAL_CATALOG=true;
---- ok
-STATEMENT CALL show_tables() return *;
---- 3
0|embeddings|NODE|local(kuzu)|
9223372036854775808|_e_hnsw_index_UPPER|REL|local(kuzu)|
9223372036854775809|_e_hnsw_index_LOWER|REL|local(kuzu)|
-STATEMENT MATCH (t1:embeddings)-[r:_e_hnsw_index_UPPER]->(t2:embeddings) HINT (t1 JOIN r) JOIN t2 WITH t1.id as id, count(*) as cnt RETURN max(cnt);
---- 1
30
-STATEMENT MATCH (t1:embeddings)-[r:_e_hnsw_index_LOWER]->(t2:embeddings) HINT (t1 JOIN r) JOIN t2 WITH t1.id as id, count(*) as cnt RETURN max(cnt);
---- 1
60
-STATEMENT CALL QUERY_HNSW_INDEX('e_hnsw_index', 'embeddings', CAST([0.0459,0.0439,0.0251,0.0318,0.0159,0.0564,0.0335,0.0273,0.0156,0.0627,0.0504,0.0116,0.0289,0.0432,0.0388,0.0282,0.0405,0.0417,0.0309,0.0338,0.0165,0.0291,0.0245,0.0208,0.0207,0.0727,0.0386,0.0145,0.0347,0.0462,0.0238,0.0333,0.0616,0.0418,0.0344,0.0448,0.0221,0.0348,0.0275,0.0319,0.0445,0.1036,0.0365,0.0168,0.0539,0.0554,0.0224,0.0432,0.1612,0.0764,0.0892,0.1059,0.0974,0.057,0.064,0.072,0.1108,0.1132,0.0399,0.035,0.0914,0.0654,0.0676,0.0481,0.2254,0.0976,0.1954,0.1424,0.1787,0.0932,0.0989,0.0909,0.1794,0.0632,0.104,0.0889,0.1972,0.0874,0.1693,0.1058,0.0982,0.0606,0.1163,0.0678,0.0728,0.065,0.029,0.0384,0.0692,0.0409,0.0631,0.0669,0.1183,0.0573,0.0792,0.1229,0.0317,0.0608,0.0381,0.0357,0.0282,0.0554,0.0146,0.0199,0.0332,0.0343,0.0357,0.0265,0.0308,0.0283,0.0285,0.0467,0.0312,0.0634,0.0273,0.0298,0.0189,0.0766,0.0253,0.0241,0.0218,0.0432,0.0398,0.0132,0.0252,0.0278,0.0262,0.0261,0.0442,0.0496,0.0578,0.0428,0.0215,0.077,0.0415,0.0274,0.0247,0.0802,0.0313,0.0157,0.0308,0.0564,0.0409,0.0171,0.0445,0.0431,0.0398,0.0338,0.015,0.0268,0.0264,0.0234,0.0342,0.0785,0.044,0.0111,0.0373,0.0732,0.0441,0.0267,0.062,0.0705,0.052,0.0588,0.0347,0.0398,0.0434,0.0626,0.0348,0.0771,0.0439,0.0277,0.0688,0.1063,0.0426,0.0315,0.1182,0.108,0.1233,0.1331,0.1333,0.0894,0.084,0.147,0.1676,0.0873,0.0449,0.0598,0.1058,0.1296,0.1083,0.0598,0.3077,0.171,0.2578,0.1343,0.3025,0.148,0.1543,0.115,0.3003,0.129,0.1667,0.1124,0.2835,0.1885,0.268,0.1088,0.1394,0.0846,0.1385,0.0723,0.1257,0.1025,0.0347,0.0402,0.096,0.0779,0.0715,0.1003,0.1399,0.1198,0.1459,0.1709,0.0286,0.0901,0.0579,0.0447,0.0259,0.072,0.027,0.0172,0.0401,0.0367,0.0308,0.0446,0.0333,0.0428,0.0485,0.0842,0.0423,0.0993,0.0507,0.0226,0.0334,0.1118,0.0455,0.0215,0.034,0.0626,0.0287,0.0151,0.0331,0.0358,0.0273,0.0148,0.0481,0.0572,0.1005,0.041,0.0327,0.039,0.0301,0.0316,0.0348,0.0455,0.0162,0.0291,0.0381,0.0462,0.0711,0.0388,0.1661,0.0992,0.131,0.1682,0.2028,0.1029,0.1002,0.1855,0.1282,0.0897,0.0436,0.0749,0.1315,0.1428,0.0925,0.0733,0.3956,0.2322,0.2294,0.3246,0.4046,0.2421,0.1451,0.292,0.3793,0.2134,0.1256,0.2624,0.3962,0.2613,0.2024,0.2949,0.1559,0.1304,0.1246,0.1057,0.1153,0.1169,0.0454,0.0656,0.1537,0.0689,0.0724,0.1378,0.1997,0.1151,0.1348,0.1879,0.0378,0.0295,0.0267,0.0348,0.0299,0.1118,0.0764,0.0252,0.0216,0.0718,0.0883,0.024,0.0255,0.0446,0.0302,0.0252,0.0384,0.0321,0.0286,0.0506,0.0365,0.0743,0.0723,0.0205,0.0279,0.1004,0.076,0.0191,0.0269,0.0528,0.0361,0.0294,0.0602,0.0421,0.0368,0.0597,0.0258,0.0539,0.0689,0.023,0.05,0.1801,0.0726,0.0224,0.0318,0.051,0.028,0.0313,0.1244,0.0621,0.0508,0.0942,0.077,0.0747,0.1127,0.0799,0.0912,0.1698,0.0911,0.05,0.1001,0.049,0.0347,0.0251,0.1328,0.0812,0.1235,0.0955,0.1485,0.0946,0.1559,0.0929,0.1354,0.0936,0.1809,0.1137,0.1718,0.0684,0.12,0.0801,0.0574,0.0528,0.0601,0.0403,0.0676,0.0763,0.0613,0.0456,0.0588,0.0613,0.1312,0.0859,0.0948,0.0499,0.052,0.0911,0.0325,0.0321,0.025,0.0217,0.0272,0.068,0.0372,0.0287,0.0296,0.0501,0.0738,0.0348,0.0293,0.0277,0.0336,0.0375,0.0297,0.0363,0.0249,0.0299,0.0266,0.0953,0.0633,0.0312,0.0307,0.0476,0.0782,0.0227,0.024,0.0336,0.0265,0.023,0.035,0.0365,0.0427,0.0366,0.0246,0.0864,0.0795,0.0218,0.0382,0.0744,0.0657,0.0357,0.0228,0.0661,0.0471,0.0146,0.0397,0.0349,0.0395,0.0464,0.0203,0.0447,0.0729,0.0205,0.0433,0.075,0.083,0.0212,0.04,0.0859,0.0581,0.0338,0.0453,0.0414,0.0623,0.0866,0.0264,0.0572,0.0933,0.0392,0.0314,0.109,0.092,0.0258,0.0393,0.0864,0.0511,0.037,0.0811,0.074,0.0903,0.1156,0.092,0.1179,0.1505,0.1201,0.1061,0.1405,0.1053,0.0583,0.0773,0.0845,0.0484,0.026,0.2203,0.1299,0.1846,0.0925,0.2102,0.1497,0.2352,0.0999,0.188,0.1712,0.2815,0.1449,0.1963,0.1609,0.1918,0.0808,0.0825,0.0618,0.0858,0.0531,0.0997,0.0804,0.0711,0.0411,0.0658,0.1118,0.1539,0.1242,0.096,0.1132,0.0893,0.122,0.0253,0.0549,0.0244,0.0252,0.0308,0.0559,0.0656,0.0236,0.0404,0.0322,0.0613,0.0598,0.0347,0.0375,0.039,0.0451,0.0408,0.0565,0.0381,0.0251,0.033,0.1151,0.089,0.0261,0.0415,0.0602,0.0762,0.0268,0.032,0.0312,0.0475,0.0221,0.0496,0.0325,0.0585,0.0332,0.0408,0.0616,0.0698,0.0395,0.0429,0.0651,0.082,0.0386,0.0405,0.0342,0.0579,0.0271,0.0939,0.0647,0.0884,0.0954,0.14,0.1124,0.1252,0.105,0.1049,0.1092,0.0498,0.0511,0.1082,0.1031,0.0545,0.0474,0.2339,0.1581,0.1452,0.1777,0.2803,0.2541,0.1426,0.1622,0.2905,0.2539,0.141,0.1648,0.2805,0.2083,0.1199,0.1783,0.0945,0.072,0.0748,0.0624,0.077,0.116,0.0494,0.0385,0.1258,0.0837,0.095,0.0853,0.1383,0.0923,0.0896,0.1217,0.0219,0.0325,0.0252,0.0218,0.0134,0.0766,0.057,0.0276,0.024,0.0582,0.0787,0.0289,0.021,0.0314,0.028,0.0213,0.0291,0.0395,0.0382,0.0296,0.0175,0.0512,0.0698,0.0267,0.0245,0.0752,0.0793,0.0232,0.0263,0.0335,0.0342,0.0296,0.0329,0.0374,0.029,0.0595,0.0235,0.043,0.0518,0.0456,0.0474,0.111,0.0492,0.0248,0.0317,0.0458,0.0297,0.0271,0.0739,0.0451,0.0458,0.0813,0.0789,0.0538,0.0909,0.0962,0.0869,0.1304,0.0801,0.0452,0.0858,0.0506,0.0496,0.0315,0.0812,0.0385,0.0875,0.0608,0.131,0.0597,0.1399,0.0882,0.1321,0.0864,0.1482,0.0969,0.1444,0.0704,0.0887,0.0552,0.0418,0.0448,0.0561,0.0491,0.0658,0.0617,0.057,0.0653,0.0605,0.0585,0.0923,0.0708,0.0723,0.0571,0.0414,0.0629,0.0273,0.0428,0.0371,0.0306,0.0345,0.0721,0.0302,0.0345,0.0224,0.052,0.0572,0.0428,0.0288,0.0317,0.0263,0.0414,0.0285,0.0348,0.0291,0.0282,0.0206,0.0705,0.049,0.0296,0.0214,0.0694,0.0656,0.0416,0.0306,0.0256,0.0202,0.0232,0.0167,0.0198,0.0256,0.0207,0.0189,0.0481,0.0428,0.0321,0.0217,0.0511,0.0531,0.0196,0.0206,0.0283,0.0257,0.0161,0.0169,0.0269,0.0302,0.0294,0.0173,0.0262,0.0316,0.0299,0.019,0.0491,0.0637,0.0251,0.0201,0.0426,0.0338,0.0233,0.0246,0.0225,0.0354,0.0744,0.0202,0.0372,0.049,0.039,0.0322,0.0636,0.0502,0.0283,0.0233,0.0521,0.0461,0.0285,0.0469,0.0615,0.0489,0.0993,0.0668,0.0872,0.1215,0.0866,0.0879,0.0996,0.0837,0.0378,0.055,0.0551,0.041,0.0295,0.1195,0.0654,0.1101,0.0597,0.1516,0.0881,0.1781,0.0614,0.1401,0.109,0.2157,0.0452,0.1375,0.085,0.1214,0.035,0.0435,0.0465,0.0585,0.0344,0.0746,0.0636,0.0506,0.0325,0.0465,0.1114,0.134,0.0451,0.0784,0.0642,0.0518,0.0704,0.0234,0.0486,0.038,0.0358,0.0266,0.0543,0.0312,0.0305,0.022,0.0611,0.0565,0.0329,0.0256,0.0419,0.0432,0.052,0.023,0.0339,0.0296,0.0324,0.0156,0.0747,0.0633,0.0197,0.0167,0.0449,0.058,0.0261,0.0237,0.0272,0.0337,0.0182,0.024,0.0259,0.0207,0.0165,0.0301,0.047,0.0439,0.0349,0.0306,0.0499,0.0622,0.0369,0.0214,0.0234,0.0291,0.0225,0.0373,0.0298,0.0386,0.0449,0.0865,0.0706,0.0773,0.0548,0.0537,0.0741,0.0339,0.0295,0.0573,0.0492,0.0204,0.0413,0.0796,0.0449,0.0652,0.0425,0.1487,0.1466,0.0773,0.0685,0.1442,0.1594,0.0879,0.0544,0.1113,0.0798,0.0579,0.0586,0.0448,0.0493,0.0358,0.0171,0.0489,0.0726,0.0282,0.0228,0.0583,0.0614,0.0752,0.0312,0.0539,0.0429,0.0458,0.0461],'FLOAT[960]'), 3) RETURN nn.id ORDER BY _distance;
-CHECK_ORDER
---- 3
Expand Down
3 changes: 2 additions & 1 deletion tools/shell/embedded_shell.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,8 @@ DWORD oldOutputCP;
void EmbeddedShell::updateTableNames() {
nodeTableNames.clear();
relTableNames.clear();
for (auto& tableEntry : database->catalog->getTableEntries(&transaction::DUMMY_TRANSACTION)) {
for (auto& tableEntry : database->catalog->getTableEntries(&transaction::DUMMY_TRANSACTION,
false /*useInternal*/)) {
if (tableEntry->getType() == catalog::CatalogEntryType::NODE_TABLE_ENTRY) {
nodeTableNames.push_back(tableEntry->getName());
} else if (tableEntry->getType() == catalog::CatalogEntryType::REL_TABLE_ENTRY) {
Expand Down

0 comments on commit c031db9

Please sign in to comment.