Skip to content

Commit 844450c

Browse files
authored
Merge pull request #31 from Maxxen/main
Fix lateral join optimizer for larger than vector size inputs, allow any table function for outer get
2 parents 5a6c41f + 9b53289 commit 844450c

File tree

2 files changed

+116
-5
lines changed

2 files changed

+116
-5
lines changed

src/hnsw/hnsw_optimize_join.cpp

+2-5
Original file line numberDiff line numberDiff line change
@@ -137,12 +137,12 @@ OperatorResultType PhysicalHNSWIndexJoin::Execute(ExecutionContext &context, Dat
137137
for (idx_t batch_idx = 0; batch_idx < batch_count; batch_idx++, state.input_idx++) {
138138

139139
// Get the next batch
140-
const auto rhs_vector_data = rhs_vector_ptr + batch_idx * rhs_vector_size;
140+
const auto rhs_vector_data = rhs_vector_ptr + state.input_idx * rhs_vector_size;
141141

142142
// Scan the index for row ids
143143
const auto match_count = hnsw_index.ExecuteMultiScan(*state.index_state, rhs_vector_data, limit);
144144
for (idx_t i = 0; i < match_count; i++) {
145-
state.match_sel.set_index(output_idx, batch_idx);
145+
state.match_sel.set_index(output_idx, state.input_idx);
146146
row_number_vector[output_idx] = i + 1; // Note: 1-indexed!
147147
output_idx++;
148148
}
@@ -337,9 +337,6 @@ bool HNSWIndexJoinOptimizer::TryOptimize(Binder &binder, ClientContext &context,
337337
MATCH_OPERATOR(delim_join.children[1], LOGICAL_GET, 0);
338338
auto outer_get_ptr = &delim_join.children[1];
339339
auto &outer_get = (*outer_get_ptr)->Cast<LogicalGet>();
340-
if (outer_get.function.name != "seq_scan") {
341-
return false;
342-
}
343340

344341
// branch
345342
// There might not be a projection here if we keep the distance function.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
require vss
2+
3+
statement ok
4+
SELECT setseed(0.1337);
5+
6+
statement ok
7+
CREATE TABLE queries (id INT, embedding FLOAT[3]);
8+
9+
statement ok
10+
INSERT INTO queries SELECT i, [random(), random(), random()]::FLOAT[3] FROM range(1, 1000) as r(i);
11+
12+
statement ok
13+
CREATE TABLE items (id INT, embedding FLOAT[3]);
14+
15+
statement ok
16+
INSERT INTO items SELECT i, [random(), random(), random()]::FLOAT[3] FROM range(1, 1000) as r(i);
17+
18+
19+
# Sanity check, total cardinality
20+
query I
21+
SELECT COUNT(*)
22+
FROM queries, LATERAL (
23+
SELECT
24+
items.id as nbr,
25+
array_distance(items.embedding, queries.embedding) as dist
26+
FROM items
27+
ORDER BY dist
28+
LIMIT 3
29+
);
30+
----
31+
2997
32+
33+
query I rowsort result_total
34+
SELECT COUNT(*)
35+
FROM queries, LATERAL (
36+
SELECT
37+
items.id as nbr,
38+
array_distance(items.embedding, queries.embedding) as dist
39+
FROM items
40+
ORDER BY dist
41+
LIMIT 3
42+
);
43+
----
44+
45+
# Sanity check, groups of 3
46+
query I rowsort result_count
47+
SELECT count(*) FROM (
48+
SELECT queries.id as id, any_value(nbr)
49+
FROM queries, LATERAL (
50+
SELECT
51+
items.id as nbr,
52+
array_distance(items.embedding, queries.embedding) as dist
53+
FROM items
54+
ORDER BY dist
55+
LIMIT 3
56+
) GROUP BY queries.id
57+
)
58+
----
59+
60+
61+
query II rowsort result_scan
62+
SELECT queries.id as id, list(nbr ORDER BY nbr) result_scan
63+
FROM queries, LATERAL (
64+
SELECT
65+
items.id as nbr,
66+
array_distance(items.embedding, queries.embedding) as dist
67+
FROM items
68+
ORDER BY dist
69+
LIMIT 3
70+
) GROUP BY queries.id
71+
----
72+
73+
74+
# Now create an index
75+
statement ok
76+
CREATE INDEX items_embedding_idx ON items USING hnsw(embedding);
77+
78+
query I rowsort result_total
79+
SELECT COUNT(*)
80+
FROM queries, LATERAL (
81+
SELECT
82+
items.id as nbr,
83+
array_distance(items.embedding, queries.embedding) as dist
84+
FROM items
85+
ORDER BY dist
86+
LIMIT 3
87+
);
88+
----
89+
90+
query I rowsort result_count
91+
SELECT count(*) FROM (
92+
SELECT queries.id as id, any_value(nbr)
93+
FROM queries, LATERAL (
94+
SELECT
95+
items.id as nbr,
96+
array_distance(items.embedding, queries.embedding) as dist
97+
FROM items
98+
ORDER BY dist
99+
LIMIT 3
100+
) GROUP BY queries.id
101+
)
102+
----
103+
104+
query II rowsort result_scan
105+
SELECT queries.id as id, list(nbr ORDER BY nbr)
106+
FROM queries, LATERAL (
107+
SELECT
108+
items.id as nbr,
109+
array_distance(items.embedding, queries.embedding) as dist
110+
FROM items
111+
ORDER BY dist
112+
LIMIT 3
113+
) GROUP BY queries.id
114+
----

0 commit comments

Comments
 (0)