Parallelise duckdb, resulting in e.g. a 2-4x speedup on a 6-core machine #1796

Merged: 19 commits, Jan 10, 2024
13 changes: 11 additions & 2 deletions splink/blocking.py
@@ -238,6 +238,7 @@ def create_blocked_pairs_sql(self, linker: Linker, where_condition, probability)
        """

        sqls.append(sql)
+
        return " UNION ALL ".join(sqls)

@@ -361,8 +362,16 @@ def block_using_rules_sqls(linker: Linker):
        sql = br.create_blocked_pairs_sql(linker, where_condition, probability)
        br_sqls.append(sql)

-    sql = " UNION ALL ".join(br_sqls)
+    unioned_sql = " UNION ALL ".join(br_sqls)
+
+    # see https://github.com/duckdb/duckdb/discussions/9710
+    # this generates a huge speedup because it triggers parallelisation
+    if linker._sql_dialect == "duckdb" and not linker._train_u_using_random_sample_mode:
+        unioned_sql = f"""
+        {unioned_sql}
+        order by 1
+        """

-    sqls.append({"sql": sql, "output_table_name": "__splink__df_blocked"})
+    sqls.append({"sql": unioned_sql, "output_table_name": "__splink__df_blocked"})

    return sqls
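
The ORDER BY trick is easy to demonstrate outside splink. A minimal sketch, assuming the duckdb Python client; the table and the two blocking rules are illustrative, not splink's actual generated SQL:

import duckdb
import pandas as pd

df = pd.DataFrame(
    {
        "unique_id": [1, 2, 3, 4],
        "surname": ["Smith", "Smith", "Jones", "Jones"],
        "dob": ["1990-01-01", "1990-01-01", "1980-05-05", "1975-03-03"],
    }
)
con = duckdb.connect()
con.register("df", df)

# a UNION ALL of one SELECT per blocking rule, roughly the shape of
# the query that block_using_rules_sqls builds
unioned_sql = """
select l.unique_id as unique_id_l, r.unique_id as unique_id_r
from df as l inner join df as r on l.surname = r.surname
union all
select l.unique_id as unique_id_l, r.unique_id as unique_id_r
from df as l inner join df as r on l.dob = r.dob
"""

# the otherwise-redundant ORDER BY pushes DuckDB into a plan that
# executes the union in parallel (see duckdb/duckdb discussion 9710);
# the ordering itself is irrelevant to the result
blocked = con.execute(f"{unioned_sql} order by 1").df()
print(blocked)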
20 changes: 18 additions & 2 deletions splink/estimate_u.py
@@ -1,10 +1,11 @@
from __future__ import annotations

import logging
+import math
from copy import deepcopy
from typing import TYPE_CHECKING, List

-from .blocking import block_using_rules_sqls
+from .blocking import block_using_rules_sqls, blocking_rule_to_obj
from .comparison_vector_values import compute_comparison_vector_values_sql
from .expectation_maximisation import (
    compute_new_parameters_sql,
@@ -51,6 +52,12 @@ def _proportion_sample_size_link_only(
    return proportion, sample_size


+def _get_duckdb_salting(max_pairs):
+    logged = math.log(max_pairs, 10)
+    logged = max(logged - 4, 0)
+    return math.ceil(2.5**logged)

Review comment (Member Author):

this is a heuristic.

max_pairs    salting
1e+00        1
1e+01        1
1e+02        1
1e+03        1
1e+04        1
1e+05        3
1e+06        7
1e+07        16
1e+08        40
2e+08        52
1e+09        98
1e+10        245
1e+11        611

Review comment (Contributor):

Might be useful to have a very brief explanation as a comment, as otherwise this function is maybe a little cryptic.
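
Picking up the reviewer's point: the heuristic uses a single partition for anything up to roughly 1e4 pairs, then grows the partition count by a factor of 2.5 per additional order of magnitude of max_pairs. A quick sketch mirroring _get_duckdb_salting that reproduces the table above:

import math

def get_duckdb_salting(max_pairs):
    # single partition up to ~1e4 pairs, then 2.5x more partitions
    # per additional order of magnitude of max_pairs
    logged = max(math.log(max_pairs, 10) - 4, 0)
    return math.ceil(2.5**logged)

for exponent in range(12):
    print(f"1e+{exponent:02d}  {get_duckdb_salting(10**exponent)}")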


def estimate_u_values(linker: Linker, max_pairs, seed=None):
logger.info("----- Estimating u probabilities using random sampling -----")

@@ -117,7 +124,16 @@ def estimate_u_values(linker: Linker, max_pairs, seed=None):
    training_linker._enqueue_sql(sql, "__splink__df_concat_with_tf_sample")
    df_sample = training_linker._execute_sql_pipeline([nodes_with_tf])

-    settings_obj._blocking_rules_to_generate_predictions = []
+    if linker._sql_dialect == "duckdb" and max_pairs > 1e5:
+        br = blocking_rule_to_obj(
+            {
+                "blocking_rule": "1=1",
+                "salting_partitions": _get_duckdb_salting(max_pairs),
+            }
+        )
+        settings_obj._blocking_rules_to_generate_predictions = [br]
+    else:
+        settings_obj._blocking_rules_to_generate_predictions = []

sqls = block_using_rules_sqls(training_linker)
for sql in sqls:
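For context, this branch is reached through splink's public training API. A minimal sketch of the calling pattern, assuming a splink 3.x install; the dataset and comparisons are illustrative:

from splink.datasets import splink_datasets
from splink.duckdb.linker import DuckDBLinker
import splink.duckdb.comparison_library as cl

df = splink_datasets.fake_1000
settings = {
    "link_type": "dedupe_only",
    "comparisons": [cl.exact_match("first_name"), cl.exact_match("surname")],
}
linker = DuckDBLinker(df, settings)

# on DuckDB with max_pairs > 1e5, u-estimation now blocks on a salted
# "1=1" rule, so the sampled cartesian join is split into partitions
# that DuckDB can execute across threads
linker.estimate_u_using_random_sampling(max_pairs=1e6)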
5 changes: 5 additions & 0 deletions splink/vertically_concatenate.py
@@ -44,6 +44,11 @@ def vertically_concatenate_sql(linker: Linker) -> str:

    salting_required = linker._settings_obj.salting_required

+    # see https://github.com/duckdb/duckdb/discussions/9710
+    # in duckdb, to parallelise we need salting
+    if linker._sql_dialect == "duckdb":
+        salting_required = True

    if salting_required:
        salt_sql = ", random() as __splink_salt"
    else:
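The forced salting amounts to one extra column on the concatenated input. A small sketch of what the salt looks like, using an illustrative input table:

import duckdb
import pandas as pd

con = duckdb.connect()
con.register("input_table", pd.DataFrame({"unique_id": [1, 2, 3]}))

# vertically_concatenate_sql now always appends
# ", random() as __splink_salt" under DuckDB, giving each row a
# uniform [0, 1) salt that downstream blocking SQL can partition on
print(con.execute("select *, random() as __splink_salt from input_table").df())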
2 changes: 1 addition & 1 deletion tests/test_correctness_of_convergence.py
@@ -68,7 +68,7 @@ def test_splink_converges_to_known_params():
# CREATE TABLE __splink__df_comparison_vectors_abc123
# and modify the following line to include the value of the hash (abc123 above)

-    cvv_hashed_tablename = "__splink__df_comparison_vectors_f9bd31158"
+    cvv_hashed_tablename = "__splink__df_comparison_vectors_17733aa10"
linker.register_table(df, cvv_hashed_tablename)

em_training_session = EMTrainingSession(