From ed3d2de1f2a2027fbea92708b82805f9af2dcd87 Mon Sep 17 00:00:00 2001
From: Robin Linacre
Date: Tue, 12 Dec 2023 14:19:30 +0000
Subject: [PATCH 01/17] parallelise duckdb

---
 splink/blocking.py | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/splink/blocking.py b/splink/blocking.py
index 47160a499c..17ab02dfe2 100644
--- a/splink/blocking.py
+++ b/splink/blocking.py
@@ -238,7 +238,17 @@ def create_blocked_pairs_sql(self, linker: Linker, where_condition, probability)
             """
             sqls.append(sql)
 
-        return " UNION ALL ".join(sqls)
+
+        unioned_sql = " UNION ALL ".join(sqls)
+
+        # see https://github.com/duckdb/duckdb/discussions/9710
+        # this generates a huge speedup because it triggers parallelisation
+        if linker._sql_dialect == "duckdb":
+            unioned_sql = f"""
+            {unioned_sql}
+            order by 1
+            """
+        return unioned_sql
 
 
 def _sql_gen_where_condition(link_type, unique_id_cols):

From 9208d747511538819c40ea3475280706bdb4a93b Mon Sep 17 00:00:00 2001
From: Robin Linacre
Date: Tue, 12 Dec 2023 14:45:03 +0000
Subject: [PATCH 02/17] order by 1 to hint parallelisation in best place

---
 splink/blocking.py | 23 +++++++++++------------
 1 file changed, 11 insertions(+), 12 deletions(-)

diff --git a/splink/blocking.py b/splink/blocking.py
index 17ab02dfe2..29e25e1c24 100644
--- a/splink/blocking.py
+++ b/splink/blocking.py
@@ -239,16 +239,7 @@ def create_blocked_pairs_sql(self, linker: Linker, where_condition, probability)
             sqls.append(sql)
 
 
-        unioned_sql = " UNION ALL ".join(sqls)
-
-        # see https://github.com/duckdb/duckdb/discussions/9710
-        # this generates a huge speedup because it triggers parallelisation
-        if linker._sql_dialect == "duckdb":
-            unioned_sql = f"""
-            {unioned_sql}
-            order by 1
-            """
-        return unioned_sql
+        return " UNION ALL ".join(sqls)
 
 
 def _sql_gen_where_condition(link_type, unique_id_cols):
@@ -371,8 +362,16 @@ def block_using_rules_sqls(linker: Linker):
         sql = br.create_blocked_pairs_sql(linker, where_condition, probability)
         br_sqls.append(sql)
 
-    sql = " UNION ALL ".join(br_sqls)
+    unioned_sql = " UNION ALL ".join(br_sqls)
+
+    # see https://github.com/duckdb/duckdb/discussions/9710
+    # this generates a huge speedup because it triggers parallelisation
+    if linker._sql_dialect == "duckdb":
+        unioned_sql = f"""
+        {unioned_sql}
+        order by 1
+        """
 
-    sqls.append({"sql": sql, "output_table_name": "__splink__df_blocked"})
+    sqls.append({"sql": unioned_sql, "output_table_name": "__splink__df_blocked"})
 
     return sqls

From ef8ec923cfcbaa93f8788b550f8073f58b8e3327 Mon Sep 17 00:00:00 2001
From: Robin Linacre
Date: Tue, 12 Dec 2023 14:56:52 +0000
Subject: [PATCH 03/17] fix convergence test

---
 tests/test_correctness_of_convergence.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_correctness_of_convergence.py b/tests/test_correctness_of_convergence.py
index 4adcecec30..34f2e136c0 100644
--- a/tests/test_correctness_of_convergence.py
+++ b/tests/test_correctness_of_convergence.py
@@ -68,7 +68,7 @@ def test_splink_converges_to_known_params():
     # CREATE TABLE __splink__df_comparison_vectors_abc123
     # and modify the following line to include the value of the hash (abc123 above)
 
-    cvv_hashed_tablename = "__splink__df_comparison_vectors_f9bd31158"
+    cvv_hashed_tablename = "__splink__df_comparison_vectors_3f3fea0c5"
     linker.register_table(df, cvv_hashed_tablename)
 
     em_training_session = EMTrainingSession(
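The effect that patches 01 and 02 exploit can be reproduced outside splink. Below is a minimal sketch against the public duckdb Python API; the table and the two blocking-rule joins are illustrative stand-ins rather than splink's generated SQL:

```python
import duckdb

con = duckdb.connect()
con.execute(
    "CREATE TABLE persons AS "
    "SELECT range AS unique_id, range % 100 AS postcode, range % 7 AS dob "
    "FROM range(1000)"
)

# Two blocking rules unioned together: the same shape of SQL that
# create_blocked_pairs_sql() builds from the blocking rules.
unioned_sql = """
SELECT l.unique_id AS unique_id_l, r.unique_id AS unique_id_r
FROM persons AS l INNER JOIN persons AS r ON l.postcode = r.postcode
UNION ALL
SELECT l.unique_id AS unique_id_l, r.unique_id AS unique_id_r
FROM persons AS l INNER JOIN persons AS r ON l.dob = r.dob
"""

# Per https://github.com/duckdb/duckdb/discussions/9710, the bare UNION ALL
# may execute on a single thread; appending a cheap ORDER BY 1 routes the
# result through a sort operator that DuckDB runs in parallel.
pairs = con.execute(f"{unioned_sql} ORDER BY 1").fetchall()
```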
From a75af8dd7981e440df948af6e1251bc34d6128b2 Mon Sep 17 00:00:00 2001
From: Robin Linacre
Date: Tue, 12 Dec 2023 17:25:32 +0000
Subject: [PATCH 04/17] doesn't seem to improve things

---
 splink/blocking.py               | 2 +-
 splink/estimate_u.py             | 8 ++++++--
 splink/vertically_concatenate.py | 5 +++++
 3 files changed, 12 insertions(+), 3 deletions(-)

diff --git a/splink/blocking.py b/splink/blocking.py
index 29e25e1c24..2171705bf9 100644
--- a/splink/blocking.py
+++ b/splink/blocking.py
@@ -366,7 +366,7 @@ def block_using_rules_sqls(linker: Linker):
 
     # see https://github.com/duckdb/duckdb/discussions/9710
    # this generates a huge speedup because it triggers parallelisation
-    if linker._sql_dialect == "duckdb":
+    if linker._sql_dialect == "duckdb" and linker.__apply_sort:
         unioned_sql = f"""
         {unioned_sql}
         order by 1
diff --git a/splink/estimate_u.py b/splink/estimate_u.py
index e019da46c7..68ad10b0f4 100644
--- a/splink/estimate_u.py
+++ b/splink/estimate_u.py
@@ -4,7 +4,7 @@
 from copy import deepcopy
 from typing import TYPE_CHECKING, List
 
-from .blocking import block_using_rules_sqls
+from .blocking import block_using_rules_sqls, blocking_rule_to_obj
 from .comparison_vector_values import compute_comparison_vector_values_sql
 from .expectation_maximisation import (
     compute_new_parameters_sql,
@@ -117,7 +117,11 @@ def estimate_u_values(linker: Linker, max_pairs, seed=None):
     training_linker._enqueue_sql(sql, "__splink__df_concat_with_tf_sample")
     df_sample = training_linker._execute_sql_pipeline([nodes_with_tf])
 
-    settings_obj._blocking_rules_to_generate_predictions = []
+    if linker._sql_dialect == "duckdb" and linker.__apply_sort:
+        br = blocking_rule_to_obj({"blocking_rule": "1=1", "salting_partitions": 2})
+        settings_obj._blocking_rules_to_generate_predictions = [br]
+    else:
+        settings_obj._blocking_rules_to_generate_predictions = []
 
     sqls = block_using_rules_sqls(training_linker)
     for sql in sqls:
diff --git a/splink/vertically_concatenate.py b/splink/vertically_concatenate.py
index 7d3f69afb1..b8020dce44 100644
--- a/splink/vertically_concatenate.py
+++ b/splink/vertically_concatenate.py
@@ -44,6 +44,11 @@ def vertically_concatenate_sql(linker: Linker) -> str:
 
     salting_reqiured = linker._settings_obj.salting_required
 
+    # see https://github.com/duckdb/duckdb/discussions/9710
+    # in duckdb to parallelise we need salting
+    if linker._sql_dialect == "duckdb":
+        salting_reqiured = True
+
     if salting_reqiured:
         salt_sql = ", random() as __splink_salt"
     else:

From 32669fbc6390f4ea2b02393451f6517006ba5377 Mon Sep 17 00:00:00 2001
From: Robin Linacre
Date: Tue, 12 Dec 2023 17:59:29 +0000
Subject: [PATCH 05/17] 4 partitions

---
 splink/estimate_u.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/splink/estimate_u.py b/splink/estimate_u.py
index 68ad10b0f4..95505a4723 100644
--- a/splink/estimate_u.py
+++ b/splink/estimate_u.py
@@ -118,7 +118,7 @@ def estimate_u_values(linker: Linker, max_pairs, seed=None):
     df_sample = training_linker._execute_sql_pipeline([nodes_with_tf])
 
     if linker._sql_dialect == "duckdb" and linker.__apply_sort:
-        br = blocking_rule_to_obj({"blocking_rule": "1=1", "salting_partitions": 2})
+        br = blocking_rule_to_obj({"blocking_rule": "1=1", "salting_partitions": 4})
         settings_obj._blocking_rules_to_generate_predictions = [br]
     else:
         settings_obj._blocking_rules_to_generate_predictions = []
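For context on what the salted `1=1` rule in patches 04 and 05 buys during u estimation: splink expands a salted blocking rule into one join per partition, each filtered on the `__splink_salt` column that `vertically_concatenate_sql` now always adds under DuckDB. A rough sketch of that expansion follows; the filter predicate shown is an assumption about the exact form splink emits:

```python
# Rough sketch of how a salted 1=1 rule becomes one UNION ALL branch per
# partition. The ceiling(...) predicate is an assumed stand-in for the
# exact filter splink generates on the __splink_salt column.
salting_partitions = 4

branches = []
for salt in range(1, salting_partitions + 1):
    branches.append(
        f"""
        select l.unique_id as unique_id_l, r.unique_id as unique_id_r
        from __splink__df_concat_with_tf_sample as l
        inner join __splink__df_concat_with_tf_sample as r
        on (1=1) and ceiling(l.__splink_salt * {salting_partitions}) = {salt}
        """
    )

blocked_pairs_sql = " UNION ALL ".join(branches)
```

Each branch scans a disjoint slice of the left-hand table, so the branches can run on separate threads.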
From 6aeb6594b3775bf96afb528de12b96c072ab8639 Mon Sep 17 00:00:00 2001
From: Robin Linacre
Date: Tue, 12 Dec 2023 18:30:10 +0000
Subject: [PATCH 06/17] works

---
 splink/blocking.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/splink/blocking.py b/splink/blocking.py
index 2171705bf9..2c60083db9 100644
--- a/splink/blocking.py
+++ b/splink/blocking.py
@@ -366,7 +366,11 @@ def block_using_rules_sqls(linker: Linker):
 
     # see https://github.com/duckdb/duckdb/discussions/9710
     # this generates a huge speedup because it triggers parallelisation
-    if linker._sql_dialect == "duckdb" and linker.__apply_sort:
+    if (
+        linker._sql_dialect == "duckdb"
+        and linker.__apply_sort
+        and not linker._train_u_using_random_sample_mode
+    ):
         unioned_sql = f"""
         {unioned_sql}
         order by 1

From 0c096fabd98638d36cf0fd3c5d0389aba15325a7 Mon Sep 17 00:00:00 2001
From: Robin Linacre
Date: Tue, 12 Dec 2023 18:31:47 +0000
Subject: [PATCH 07/17] fix

---
 splink/blocking.py   | 6 +-----
 splink/estimate_u.py | 2 +-
 2 files changed, 2 insertions(+), 6 deletions(-)

diff --git a/splink/blocking.py b/splink/blocking.py
index 2c60083db9..3b2e703465 100644
--- a/splink/blocking.py
+++ b/splink/blocking.py
@@ -366,11 +366,7 @@ def block_using_rules_sqls(linker: Linker):
 
     # see https://github.com/duckdb/duckdb/discussions/9710
     # this generates a huge speedup because it triggers parallelisation
-    if (
-        linker._sql_dialect == "duckdb"
-        and linker.__apply_sort
-        and not linker._train_u_using_random_sample_mode
-    ):
+    if linker._sql_dialect == "duckdb" and not linker._train_u_using_random_sample_mode:
         unioned_sql = f"""
         {unioned_sql}
         order by 1
diff --git a/splink/estimate_u.py b/splink/estimate_u.py
index 95505a4723..8aaa3905f4 100644
--- a/splink/estimate_u.py
+++ b/splink/estimate_u.py
@@ -117,7 +117,7 @@ def estimate_u_values(linker: Linker, max_pairs, seed=None):
     training_linker._enqueue_sql(sql, "__splink__df_concat_with_tf_sample")
     df_sample = training_linker._execute_sql_pipeline([nodes_with_tf])
 
-    if linker._sql_dialect == "duckdb" and linker.__apply_sort:
+    if linker._sql_dialect == "duckdb":
         br = blocking_rule_to_obj({"blocking_rule": "1=1", "salting_partitions": 4})
         settings_obj._blocking_rules_to_generate_predictions = [br]
     else:

From 90f6c431b628240c5b2fedd29d54dc2036e7ceeb Mon Sep 17 00:00:00 2001
From: Robin Linacre
Date: Tue, 12 Dec 2023 19:10:10 +0000
Subject: [PATCH 08/17] scale salting by data

---
 splink/estimate_u.py | 16 ++++++++++++++--
 1 file changed, 14 insertions(+), 2 deletions(-)

diff --git a/splink/estimate_u.py b/splink/estimate_u.py
index 8aaa3905f4..fd195c7021 100644
--- a/splink/estimate_u.py
+++ b/splink/estimate_u.py
@@ -117,8 +117,20 @@ def estimate_u_values(linker: Linker, max_pairs, seed=None):
     training_linker._enqueue_sql(sql, "__splink__df_concat_with_tf_sample")
     df_sample = training_linker._execute_sql_pipeline([nodes_with_tf])
 
-    if linker._sql_dialect == "duckdb":
-        br = blocking_rule_to_obj({"blocking_rule": "1=1", "salting_partitions": 4})
+    if linker._sql_dialect == "duckdb" and max_pairs > 1e5:
+        if max_pairs < 1e6:
+            salting_partitions = 2
+        elif max_pairs < 1e7:
+            salting_partitions = 4
+        elif max_pairs < 1e8:
+            salting_partitions = 10
+        elif max_pairs < 1e9:
+            salting_partitions = 20
+        else:
+            salting_partitions = 50
+        br = blocking_rule_to_obj(
+            {"blocking_rule": "1=1", "salting_partitions": salting_partitions}
+        )
         settings_obj._blocking_rules_to_generate_predictions = [br]
     else:
         settings_obj._blocking_rules_to_generate_predictions = []
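Restated outside the diff, the tiering in patch 08 maps the requested sample size to a partition count as follows; the `<= 1e5` branch is an assumption standing in for "no salted rule is applied at all":

```python
# Standalone restatement of the tier logic from patch 08/17, with spot
# checks; the tier values come straight from the if/elif chain above.
def salting_partitions_for(max_pairs: float) -> int:
    if max_pairs <= 1e5:
        return 1  # assumption: below the threshold, no salting is applied
    if max_pairs < 1e6:
        return 2
    elif max_pairs < 1e7:
        return 4
    elif max_pairs < 1e8:
        return 10
    elif max_pairs < 1e9:
        return 20
    else:
        return 50

assert salting_partitions_for(5e6) == 4
assert salting_partitions_for(1e8) == 20
assert salting_partitions_for(1e10) == 50
```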
From 231d00fcc25ee4c5c8c4dd04020ba97f8e90bf7f Mon Sep 17 00:00:00 2001
From: Robin Linacre
Date: Tue, 12 Dec 2023 19:17:02 +0000
Subject: [PATCH 09/17] fix convergence test

---
 tests/test_correctness_of_convergence.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_correctness_of_convergence.py b/tests/test_correctness_of_convergence.py
index 34f2e136c0..8388f2de59 100644
--- a/tests/test_correctness_of_convergence.py
+++ b/tests/test_correctness_of_convergence.py
@@ -68,7 +68,7 @@ def test_splink_converges_to_known_params():
     # CREATE TABLE __splink__df_comparison_vectors_abc123
     # and modify the following line to include the value of the hash (abc123 above)
 
-    cvv_hashed_tablename = "__splink__df_comparison_vectors_3f3fea0c5"
+    cvv_hashed_tablename = "__splink__df_comparison_vectors_17733aa10"
     linker.register_table(df, cvv_hashed_tablename)
 
     em_training_session = EMTrainingSession(

From dec3078b5bd650c29397b874e8aa9b137564994e Mon Sep 17 00:00:00 2001
From: Robin Linacre
Date: Sun, 24 Dec 2023 14:10:12 +0000
Subject: [PATCH 10/17] Add duckdb salting based on max_pairs

---
 splink/estimate_u.py | 22 +++++++++++-----------
 1 file changed, 11 insertions(+), 11 deletions(-)

diff --git a/splink/estimate_u.py b/splink/estimate_u.py
index fd195c7021..849e77daa7 100644
--- a/splink/estimate_u.py
+++ b/splink/estimate_u.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import logging
+import math
 from copy import deepcopy
 from typing import TYPE_CHECKING, List
 
@@ -51,6 +52,12 @@ def _proportion_sample_size_link_only(
     return proportion, sample_size
 
 
+def _get_duckdb_salting(max_pairs):
+    logged = math.log(max_pairs, 10)
+    logged = max(logged - 4, 0)
+    return math.ceil(2.5**logged)
+
+
 def estimate_u_values(linker: Linker, max_pairs, seed=None):
     logger.info("----- Estimating u probabilities using random sampling -----")
 
@@ -118,18 +125,11 @@ def estimate_u_values(linker: Linker, max_pairs, seed=None):
     df_sample = training_linker._execute_sql_pipeline([nodes_with_tf])
 
     if linker._sql_dialect == "duckdb" and max_pairs > 1e5:
-        if max_pairs < 1e6:
-            salting_partitions = 2
-        elif max_pairs < 1e7:
-            salting_partitions = 4
-        elif max_pairs < 1e8:
-            salting_partitions = 10
-        elif max_pairs < 1e9:
-            salting_partitions = 20
-        else:
-            salting_partitions = 50
         br = blocking_rule_to_obj(
-            {"blocking_rule": "1=1", "salting_partitions": salting_partitions}
+            {
+                "blocking_rule": "1=1",
+                "salting_partitions": _get_duckdb_salting(max_pairs),
+            }
         )
         settings_obj._blocking_rules_to_generate_predictions = [br]
     else:

From 1d8b64b5b18acf9eb25f1beafd0dcf5893793af4 Mon Sep 17 00:00:00 2001
From: Robin Linacre
Date: Sun, 24 Dec 2023 14:20:06 +0000
Subject: [PATCH 11/17] Refactor _get_duckdb_salting to double the returned
 value

---
 splink/estimate_u.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/splink/estimate_u.py b/splink/estimate_u.py
index 849e77daa7..eaa5d29bb4 100644
--- a/splink/estimate_u.py
+++ b/splink/estimate_u.py
@@ -55,7 +55,7 @@ def _proportion_sample_size_link_only(
 def _get_duckdb_salting(max_pairs):
     logged = math.log(max_pairs, 10)
     logged = max(logged - 4, 0)
-    return math.ceil(2.5**logged)
+    return math.ceil(2.5**logged) * 2
 
 
 def estimate_u_values(linker: Linker, max_pairs, seed=None):
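Worked values for `_get_duckdb_salting` as of patch 11 (the doubled variant that patch 12 reverts below): `logged` is the base-10 magnitude of `max_pairs` above 1e4, so the partition count grows geometrically with the sample size.

```python
import math

def _get_duckdb_salting(max_pairs):
    # As defined in patch 11/17; patch 12 removes the trailing "* 2".
    logged = math.log(max_pairs, 10)
    logged = max(logged - 4, 0)  # clamps to 0 at or below 1e4 pairs
    return math.ceil(2.5**logged) * 2

# Spot checks: 2.5**1 = 2.5, 2.5**2 = 6.25, 2.5**3 = 15.625
assert _get_duckdb_salting(1e5) == 6   # ceil(2.5) * 2
assert _get_duckdb_salting(1e6) == 14  # ceil(6.25) * 2
assert _get_duckdb_salting(1e7) == 32  # ceil(15.625) * 2
```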
From 0c467204fda4f0e6348caea18db3b01a375953f6 Mon Sep 17 00:00:00 2001
From: Robin Linacre
Date: Thu, 4 Jan 2024 09:12:54 +0000
Subject: [PATCH 12/17] revert change that doubled cpus; it was only used for
 benchmarking

---
 splink/estimate_u.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/splink/estimate_u.py b/splink/estimate_u.py
index eaa5d29bb4..849e77daa7 100644
--- a/splink/estimate_u.py
+++ b/splink/estimate_u.py
@@ -55,7 +55,7 @@ def _proportion_sample_size_link_only(
 def _get_duckdb_salting(max_pairs):
     logged = math.log(max_pairs, 10)
     logged = max(logged - 4, 0)
-    return math.ceil(2.5**logged) * 2
+    return math.ceil(2.5**logged)
 
 
 def estimate_u_values(linker: Linker, max_pairs, seed=None):

From 0b95034968811b8bf122757a494518ed041b6b56 Mon Sep 17 00:00:00 2001
From: Robin Linacre
Date: Wed, 10 Jan 2024 13:57:11 +0000
Subject: [PATCH 13/17] Refactor blocking and prediction SQL queries

---
 splink/blocking.py | 12 ++----------
 splink/predict.py  |  5 +++++
 2 files changed, 7 insertions(+), 10 deletions(-)

diff --git a/splink/blocking.py b/splink/blocking.py
index 3b2e703465..b242100575 100644
--- a/splink/blocking.py
+++ b/splink/blocking.py
@@ -362,16 +362,8 @@ def block_using_rules_sqls(linker: Linker):
         sql = br.create_blocked_pairs_sql(linker, where_condition, probability)
         br_sqls.append(sql)
 
-    unioned_sql = " UNION ALL ".join(br_sqls)
-
-    # see https://github.com/duckdb/duckdb/discussions/9710
-    # this generates a huge speedup because it triggers parallelisation
-    if linker._sql_dialect == "duckdb" and not linker._train_u_using_random_sample_mode:
-        unioned_sql = f"""
-        {unioned_sql}
-        order by 1
-        """
+    sql = " UNION ALL ".join(br_sqls)
 
-    sqls.append({"sql": unioned_sql, "output_table_name": "__splink__df_blocked"})
+    sqls.append({"sql": sql, "output_table_name": "__splink__df_blocked"})
 
     return sqls
diff --git a/splink/predict.py b/splink/predict.py
index 95f0404176..c1b1b674f5 100644
--- a/splink/predict.py
+++ b/splink/predict.py
@@ -65,6 +65,10 @@ def predict_from_comparison_vectors_sqls(
     else:
         threshold_expr = ""
 
+    if settings_obj._sql_dialect == "duckdb":
+        order_by_statement = "order by 1"
+    else:
+        order_by_statement = ""
     sql = f"""
     select
     log2({bayes_factor_expr}) as match_weight,
@@ -72,6 +76,7 @@ def predict_from_comparison_vectors_sqls(
     {select_cols_expr}
     {clerical_match_score}
     from __splink__df_match_weight_parts
     {threshold_expr}
+    {order_by_statement}
     """
     sql = {

From fc683ed1da664303fbe2d5127d1ed6bef869c596 Mon Sep 17 00:00:00 2001
From: Robin Linacre
Date: Wed, 10 Jan 2024 13:58:01 +0000
Subject: [PATCH 14/17] Remove unnecessary blank line in SaltedBlockingRule
 class

---
 splink/blocking.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/splink/blocking.py b/splink/blocking.py
index b242100575..47160a499c 100644
--- a/splink/blocking.py
+++ b/splink/blocking.py
@@ -238,7 +238,6 @@ def create_blocked_pairs_sql(self, linker: Linker, where_condition, probability)
             """
             sqls.append(sql)
 
-
         return " UNION ALL ".join(sqls)
 
 
 def _sql_gen_where_condition(link_type, unique_id_cols):
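Patch 13 relocates the parallelisation hint from the blocking stage to the last SELECT of `predict_from_comparison_vectors_sqls`, so a single sort covers the whole prediction pipeline. A simplified sketch of the resulting statement's shape; the real query interpolates `bayes_factor_expr`, `select_cols_expr` and the threshold clause:

```python
# Simplified shape of the final predict SQL after patch 13/17. "order by 1"
# sorts on the first selected column (match_weight); log2(bf) stands in for
# the full interpolated Bayes factor expression.
def final_predict_sql(sql_dialect: str) -> str:
    order_by_statement = "order by 1" if sql_dialect == "duckdb" else ""
    return f"""
    select
    log2(bf) as match_weight,
    unique_id_l,
    unique_id_r
    from __splink__df_match_weight_parts
    {order_by_statement}
    """
```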
From 0bca60283bb4f274b32939368292b8f4dffa4b4d Mon Sep 17 00:00:00 2001
From: Robin Linacre
Date: Wed, 10 Jan 2024 13:59:22 +0000
Subject: [PATCH 15/17] Update estimate_u.py: Import multiprocessing and
 remove unused function

---
 splink/estimate_u.py | 12 +++---------
 1 file changed, 3 insertions(+), 9 deletions(-)

diff --git a/splink/estimate_u.py b/splink/estimate_u.py
index 849e77daa7..a62bac6266 100644
--- a/splink/estimate_u.py
+++ b/splink/estimate_u.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 import logging
-import math
+import multiprocessing
 from copy import deepcopy
 from typing import TYPE_CHECKING, List
 
@@ -52,12 +52,6 @@ def _proportion_sample_size_link_only(
     return proportion, sample_size
 
 
-def _get_duckdb_salting(max_pairs):
-    logged = math.log(max_pairs, 10)
-    logged = max(logged - 4, 0)
-    return math.ceil(2.5**logged)
-
-
 def estimate_u_values(linker: Linker, max_pairs, seed=None):
     logger.info("----- Estimating u probabilities using random sampling -----")
 
@@ -124,11 +118,11 @@ def estimate_u_values(linker: Linker, max_pairs, seed=None):
     training_linker._enqueue_sql(sql, "__splink__df_concat_with_tf_sample")
     df_sample = training_linker._execute_sql_pipeline([nodes_with_tf])
 
-    if linker._sql_dialect == "duckdb" and max_pairs > 1e5:
+    if linker._sql_dialect == "duckdb" and max_pairs > 1e4:
         br = blocking_rule_to_obj(
             {
                 "blocking_rule": "1=1",
-                "salting_partitions": _get_duckdb_salting(max_pairs),
+                "salting_partitions": multiprocessing.cpu_count(),
             }
         )
         settings_obj._blocking_rules_to_generate_predictions = [br]

From 7168154edd757ebe67a9b43d913585412d1c6fd2 Mon Sep 17 00:00:00 2001
From: Robin Linacre
Date: Wed, 10 Jan 2024 14:07:40 +0000
Subject: [PATCH 16/17] Update cvv_hashed_tablename in
 test_correctness_of_convergence.py

---
 tests/test_correctness_of_convergence.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_correctness_of_convergence.py b/tests/test_correctness_of_convergence.py
index 8388f2de59..8a45dbdbbe 100644
--- a/tests/test_correctness_of_convergence.py
+++ b/tests/test_correctness_of_convergence.py
@@ -68,7 +68,7 @@ def test_splink_converges_to_known_params():
     # CREATE TABLE __splink__df_comparison_vectors_abc123
     # and modify the following line to include the value of the hash (abc123 above)
 
-    cvv_hashed_tablename = "__splink__df_comparison_vectors_17733aa10"
+    cvv_hashed_tablename = "__splink__df_comparison_vectors_0247b2f3d"
     linker.register_table(df, cvv_hashed_tablename)
 
     em_training_session = EMTrainingSession(

From cc3ac6803eecdd6ac07232f584f72d52b0aee917 Mon Sep 17 00:00:00 2001
From: Robin Linacre
Date: Wed, 10 Jan 2024 14:18:39 +0000
Subject: [PATCH 17/17] Update changelog

---
 CHANGELOG.md | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0f792ceef2..509107691f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Changed
 
+- Splink now fully parallelises data linkage when using DuckDB ([#1796](https://github.com/moj-analytical-services/splink/pull/1796))
+
 ### Fixed
 
 ## [3.9.10] - 2023-12-07
@@ -20,7 +22,7 @@
 
 ### Fixed
 
-- Fixed issue with `_source_dataset_col` and `_source_dataset_input_column` ([#1731](https://github.com/moj-analytical-services/splink/pull/1731)) 
+- Fixed issue with `_source_dataset_col` and `_source_dataset_input_column` ([#1731](https://github.com/moj-analytical-services/splink/pull/1731))
 - Delete cached tables before resetting the cache ([#1752](https://github.com/moj-analytical-services/splink/pull/1752)
 
 ## [3.9.9] - 2023-11-14
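Taken together, the series requires no caller-side change. A sketch of the end-state behaviour, assuming the splink 3.x public API (`DuckDBLinker`, `estimate_u_using_random_sampling`) and the bundled demo dataset:

```python
import splink.duckdb.comparison_library as cl
from splink.datasets import splink_datasets
from splink.duckdb.linker import DuckDBLinker

df = splink_datasets.fake_1000

settings = {
    "link_type": "dedupe_only",
    "comparisons": [cl.exact_match("first_name"), cl.exact_match("dob")],
    "blocking_rules_to_generate_predictions": ["l.surname = r.surname"],
}

linker = DuckDBLinker(df, settings)

# Under DuckDB, splink now salts the internal 1=1 sampling rule with
# multiprocessing.cpu_count() partitions whenever max_pairs > 1e4, so this
# call should saturate all cores rather than one.
linker.estimate_u_using_random_sampling(max_pairs=1e6)
```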