From bff6f57a0dab566d1b0d320a94be618fc51e6eea Mon Sep 17 00:00:00 2001 From: Ross Kennedy Date: Wed, 5 Feb 2025 18:25:07 +0000 Subject: [PATCH 01/10] make multiple threshold cluster public and add to the linker --- splink/clustering.py | 4 +- .../internals/linker_components/clustering.py | 289 ++++++++++++++++++ 2 files changed, 291 insertions(+), 2 deletions(-) diff --git a/splink/clustering.py b/splink/clustering.py index ea22d068d2..5a45817a50 100644 --- a/splink/clustering.py +++ b/splink/clustering.py @@ -1,3 +1,3 @@ -from .internals.clustering import cluster_pairwise_predictions_at_threshold +from .internals.clustering import cluster_pairwise_predictions_at_threshold, cluster_pairwise_predictions_at_multiple_thresholds -__all__ = ["cluster_pairwise_predictions_at_threshold"] +__all__ = ["cluster_pairwise_predictions_at_threshold", "cluster_pairwise_predictions_at_multiple_thresholds"] diff --git a/splink/internals/linker_components/clustering.py b/splink/internals/linker_components/clustering.py index 2e0f8aed22..da1302fae0 100644 --- a/splink/internals/linker_components/clustering.py +++ b/splink/internals/linker_components/clustering.py @@ -1,7 +1,16 @@ from __future__ import annotations +import logging + from typing import TYPE_CHECKING, Optional +from splink.internals.clustering import ( + cluster_pairwise_predictions_at_threshold, + _get_cluster_stats_sql, + _calculate_stable_clusters_at_new_threshold, + _generate_detailed_cluster_comparison_sql, + _generate_cluster_summary_stats_sql +) from splink.internals.connected_components import ( solve_connected_components, ) @@ -13,6 +22,7 @@ ) from splink.internals.misc import ( threshold_args_to_match_prob, + threshold_args_to_match_prob_list ) from splink.internals.pipeline import CTEPipeline from splink.internals.splink_dataframe import SplinkDataFrame @@ -177,6 +187,285 @@ def cluster_pairwise_predictions_at_threshold( return df_clustered_with_input_data + def cluster_pairwise_predictions_at_multiple_thresholds( + self, + df_predict: SplinkDataFrame, + match_probability_thresholds: Optional[list[float]] | None = None, + match_weight_thresholds: Optional[list[float]] | None = None, + output_cluster_summary_stats: bool = False, + ) -> SplinkDataFrame: + """Clusters the pairwise match predictions at multiple thresholds using + the connected components graph clustering algorithm. + + This function efficiently computes clusters for multiple thresholds by starting + with the lowest threshold and incrementally updating clusters for higher thresholds. + + If your node and edge column names follow Splink naming conventions, then you can + omit edge_id_column_name_left and edge_id_column_name_right. For example, if you + have a table of nodes with a column `unique_id`, it would be assumed that the + edge table has columns `unique_id_l` and `unique_id_r`. + + Args: + nodes (AcceptableInputTableType): The table containing node information + edges (AcceptableInputTableType): The table containing edge information + db_api (DatabaseAPISubClass): The database API to use for querying + node_id_column_name (str): The name of the column containing node IDs + match_probability_thresholds (list[float] | None): List of match probability + thresholds to compute clusters for + match_weight_thresholds (list[float] | None): List of match weight thresholds + to compute clusters for + edge_id_column_name_left (Optional[str]): The name of the column containing + left edge IDs. 
If not provided, assumed to be f"{node_id_column_name}_l" + edge_id_column_name_right (Optional[str]): The name of the column containing + right edge IDs. If not provided, assumed to be f"{node_id_column_name}_r" + output_cluster_summary_stats (bool): If True, output summary statistics + for each threshold instead of full cluster information + + Returns: + SplinkDataFrame: A SplinkDataFrame containing cluster information for all + thresholds. If output_cluster_summary_stats is True, it contains summary + statistics (number of clusters, max cluster size, avg cluster size) for + each threshold. + + Examples: + ```python + from splink import DuckDBAPI + from splink.clustering import ( + cluster_pairwise_predictions_at_multiple_thresholds + ) + + db_api = DuckDBAPI() + + nodes = [ + {"my_id": 1}, + {"my_id": 2}, + {"my_id": 3}, + {"my_id": 4}, + {"my_id": 5}, + {"my_id": 6}, + ] + + edges = [ + {"n_1": 1, "n_2": 2, "match_probability": 0.8}, + {"n_1": 3, "n_2": 2, "match_probability": 0.9}, + {"n_1": 4, "n_2": 5, "match_probability": 0.99}, + ] + + thresholds = [0.5, 0.7, 0.9] + + cc = cluster_pairwise_predictions_at_multiple_thresholds( + nodes, + edges, + node_id_column_name="my_id", + edge_id_column_name_left="n_1", + edge_id_column_name_right="n_2", + db_api=db_api, + match_probability_thresholds=thresholds, + ) + + cc.as_duckdbpyrelation() + ``` + """ + + # Strategy to cluster at multiple thresholds: + # 1. Cluster at the lowest threshold + # 2. At next threshold, note that some clusters do not need to be recomputed. + # Specifically, those where the min probability within the cluster is + # greater than this next threshold. + # 3. Compute remaining clusters which _are_ affected by the next threshold. + # 4. Repeat for remaining thresholds. + + # Need to get nodes and edges in a format suitable to pass to + # cluster_pairwise_predictions_at_threshold + linker = self._linker + db_api = linker._db_api + + pipeline = CTEPipeline() + + enqueue_df_concat(linker, pipeline) + + uid_cols = linker._settings_obj.column_info_settings.unique_id_input_columns + uid_concat_edges_l = _composite_unique_id_from_edges_sql(uid_cols, "l") + uid_concat_edges_r = _composite_unique_id_from_edges_sql(uid_cols, "r") + uid_concat_nodes = _composite_unique_id_from_nodes_sql(uid_cols, None) + + + # Input could either be user data, or a SplinkDataFrame + sql = f""" + select + {uid_concat_nodes} as node_id + from __splink__df_concat + """ + pipeline.enqueue_sql(sql, "__splink__df_nodes_with_composite_ids") + + nodes_with_composite_ids = db_api.sql_pipeline_to_splink_dataframe(pipeline) + + has_match_prob_col = "match_probability" in [ + c.unquote().name for c in df_predict.columns + ] + + is_match_weight = ( + match_weight_thresholds is not None and match_probability_thresholds is None + ) + + match_probability_thresholds = threshold_args_to_match_prob_list( + match_probability_thresholds, match_weight_thresholds + ) + + if match_probability_thresholds is None or len(match_probability_thresholds) == 0: + raise ValueError( + "Must provide either match_probability_thresholds " + "or match_weight_thresholds" + ) + + if not has_match_prob_col and match_probability_thresholds is not None: + raise ValueError( + "df_predict must have a column called 'match_probability' if " + "threshold_match_probability is provided" + ) + + initial_threshold = match_probability_thresholds.pop(0) + all_results = {} + + # Templated name must be used here because it could be the output + # of a deterministic link i.e. 
the physical name is not known for sure
+        sql = f"""
+        select
+            {uid_concat_edges_l} as node_id_l,
+            {uid_concat_edges_r} as node_id_r,
+            match_probability
+        from {df_predict.physical_name}
+        """
+        pipeline.enqueue_sql(sql, "__splink__df_edges_from_predict")
+
+        edges_table_with_composite_ids = db_api.sql_pipeline_to_splink_dataframe(
+            pipeline
+        )
+
+        # First cluster at the lowest threshold
+        cc = cluster_pairwise_predictions_at_threshold(
+            nodes=nodes_with_composite_ids,
+            edges=edges_table_with_composite_ids,
+            db_api=db_api,
+            node_id_column_name="node_id",
+            edge_id_column_name_left="node_id_l",
+            edge_id_column_name_right="node_id_r",
+            threshold_match_probability=initial_threshold,
+        )
+
+        if output_cluster_summary_stats:
+            pipeline = CTEPipeline([cc])
+            sqls = _get_cluster_stats_sql(cc)
+            pipeline.enqueue_list_of_sqls(sqls)
+            cc_summary = db_api.sql_pipeline_to_splink_dataframe(pipeline)
+            all_results[initial_threshold] = cc_summary
+        else:
+            all_results[initial_threshold] = cc
+
+        previous_threshold = initial_threshold
+        for new_threshold in match_probability_thresholds:
+            # Get stable nodes
+            logger.info(f"--------Clustering at threshold {new_threshold}--------")
+            pipeline = CTEPipeline([cc, edges_table_with_composite_ids])
+
+            sqls = _calculate_stable_clusters_at_new_threshold(
+                edges_sdf=edges_table_with_composite_ids,
+                cc=cc,
+                node_id_column_name="node_id",
+                edge_id_column_name_left=uid_concat_edges_l,
+                edge_id_column_name_right=uid_concat_edges_r,
+                previous_threshold_match_probability=previous_threshold,
+                new_threshold_match_probability=new_threshold,
+            )
+
+            pipeline.enqueue_list_of_sqls(sqls)
+            stable_clusters = db_api.sql_pipeline_to_splink_dataframe(pipeline)
+
+            # Derive nodes in play and edges in play by removing stable nodes. 
Then
+            # run cluster_pairwise_predictions_at_threshold at new threshold
+
+            pipeline = CTEPipeline([nodes_with_composite_ids, stable_clusters])
+            sql = f"""
+            select *
+            from {nodes_with_composite_ids.templated_name}
+            where node_id not in
+                (select node_id from {stable_clusters.templated_name})
+            """
+            pipeline.enqueue_sql(sql, "__splink__nodes_in_play")
+            nodes_in_play = db_api.sql_pipeline_to_splink_dataframe(pipeline)
+
+            pipeline = CTEPipeline([nodes_in_play, edges_table_with_composite_ids])
+            sql = f"""
+            select *
+            from {edges_table_with_composite_ids.templated_name}
+            where {uid_concat_edges_l} in
+                (select node_id from {nodes_in_play.templated_name})
+            and {uid_concat_edges_r} in
+                (select node_id from {nodes_in_play.templated_name})
+            """
+            pipeline.enqueue_sql(sql, "__splink__edges_in_play")
+            edges_in_play = db_api.sql_pipeline_to_splink_dataframe(pipeline)
+
+            marginal_new_clusters = cluster_pairwise_predictions_at_threshold(
+                nodes_in_play,
+                edges_in_play,
+                node_id_column_name="node_id",
+                edge_id_column_name_left=uid_concat_edges_l,
+                edge_id_column_name_right=uid_concat_edges_r,
+                db_api=db_api,
+                threshold_match_probability=new_threshold,
+            )
+
+            pipeline = CTEPipeline([stable_clusters, marginal_new_clusters])
+            sql = f"""
+            SELECT node_id, cluster_id
+            FROM {stable_clusters.templated_name}
+            UNION ALL
+            SELECT node_id, cluster_id
+            FROM {marginal_new_clusters.templated_name}
+            """
+
+            pipeline.enqueue_sql(sql, "__splink__clusters_at_threshold")
+
+            previous_cc = cc
+            cc = db_api.sql_pipeline_to_splink_dataframe(pipeline)
+
+            previous_threshold = new_threshold
+
+            edges_in_play.drop_table_from_database_and_remove_from_cache()
+            nodes_in_play.drop_table_from_database_and_remove_from_cache()
+            stable_clusters.drop_table_from_database_and_remove_from_cache()
+            marginal_new_clusters.drop_table_from_database_and_remove_from_cache()
+
+            if output_cluster_summary_stats:
+                pipeline = CTEPipeline([cc])
+                sqls = _get_cluster_stats_sql(cc)
+                pipeline.enqueue_list_of_sqls(sqls)
+                cc_summary = db_api.sql_pipeline_to_splink_dataframe(pipeline)
+                all_results[new_threshold] = cc_summary
+                previous_cc.drop_table_from_database_and_remove_from_cache()
+            else:
+                all_results[new_threshold] = cc
+
+        if output_cluster_summary_stats:
+            sql = _generate_cluster_summary_stats_sql(all_results)
+        else:
+            sql = _generate_detailed_cluster_comparison_sql(
+                all_results,
+                unique_id_col="node_id",
+                is_match_weight=is_match_weight,
+            )
+
+        pipeline = CTEPipeline()
+        pipeline.enqueue_sql(sql, "__splink__clusters_at_all_thresholds")
+        joined = db_api.sql_pipeline_to_splink_dataframe(pipeline)
+
+        for v in all_results.values():
+            v.drop_table_from_database_and_remove_from_cache()
+        cc.drop_table_from_database_and_remove_from_cache()
+
+        return joined
+
     def _compute_metrics_nodes(
         self,
         df_predict: SplinkDataFrame,

From efc731f5734f67647a29b8c7b9d9d4d3d139a589 Mon Sep 17 00:00:00 2001
From: Ross Kennedy
Date: Tue, 11 Feb 2025 16:07:51 +0000
Subject: [PATCH 02/10] working version of multi threshold clustering

---
 splink/internals/clustering.py                |  2 +-
 .../internals/linker_components/clustering.py | 51 +++++++++++++++----
 2 files changed, 41 insertions(+), 12 deletions(-)

diff --git a/splink/internals/clustering.py b/splink/internals/clustering.py
index 036ed40e25..dcf343c83b 100644
--- a/splink/internals/clustering.py
+++ b/splink/internals/clustering.py
@@ -474,7 +474,7 @@ def cluster_pairwise_predictions_at_multiple_thresholds(
         edge_id_column_name_left,
         edge_id_column_name_right,
     )
-
+    logger.info(f"--------Clustering at threshold {initial_threshold}--------")
     # First cluster at the lowest threshold
     cc = cluster_pairwise_predictions_at_threshold(
         nodes=nodes_sdf,
diff --git a/splink/internals/linker_components/clustering.py b/splink/internals/linker_components/clustering.py
index da1302fae0..781c97861a 100644
--- a/splink/internals/linker_components/clustering.py
+++ b/splink/internals/linker_components/clustering.py
@@ -38,6 +38,7 @@
 if TYPE_CHECKING:
     from splink.internals.linker import Linker
 
+logger = logging.getLogger(__name__)
 
 class LinkerClustering:
     """Cluster the results of the linkage model and analyse clusters, accessed via
@@ -327,21 +328,30 @@ def cluster_pairwise_predictions_at_multiple_thresholds(
         initial_threshold = match_probability_thresholds.pop(0)
         all_results = {}
 
+        match_p_expr = ""
+        match_p_select_expr = ""
+        if initial_threshold is not None:
+            match_p_expr = f"where match_probability >= {initial_threshold}"
+            match_p_select_expr = ", match_probability"
+
+        pipeline = CTEPipeline([df_predict])
+
         # Templated name must be used here because it could be the output
         # of a deterministic link i.e. the physical name is not known for sure
         sql = f"""
         select
             {uid_concat_edges_l} as node_id_l,
-            {uid_concat_edges_r} as node_id_r,
-            match_probability
-        from {df_predict.physical_name}
+            {uid_concat_edges_r} as node_id_r
+            {match_p_select_expr}
+        from {df_predict.templated_name}
+        {match_p_expr}
         """
         pipeline.enqueue_sql(sql, "__splink__df_edges_from_predict")
 
         edges_table_with_composite_ids = db_api.sql_pipeline_to_splink_dataframe(
             pipeline
         )
-
+        logger.info(f"--------Clustering at threshold {initial_threshold}--------")
         # First cluster at the lowest threshold
         cc = cluster_pairwise_predictions_at_threshold(
             nodes=nodes_with_composite_ids,
@@ -372,8 +382,8 @@ def cluster_pairwise_predictions_at_multiple_thresholds(
             edges_sdf=edges_table_with_composite_ids,
             cc=cc,
             node_id_column_name="node_id",
-            edge_id_column_name_left=uid_concat_edges_l,
-            edge_id_column_name_right=uid_concat_edges_r,
+            edge_id_column_name_left="node_id_l",
+            edge_id_column_name_right="node_id_r",
             previous_threshold_match_probability=previous_threshold,
             new_threshold_match_probability=new_threshold,
         )
@@ -398,9 +408,9 @@ def cluster_pairwise_predictions_at_multiple_thresholds(
         sql = f"""
         select *
         from {edges_table_with_composite_ids.templated_name}
-        where {uid_concat_edges_l} in
+        where node_id_l in
             (select node_id from {nodes_in_play.templated_name})
-        and {uid_concat_edges_r} in
+        and node_id_r in
             (select node_id from {nodes_in_play.templated_name})
         """
         pipeline.enqueue_sql(sql, "__splink__edges_in_play")
         edges_in_play = db_api.sql_pipeline_to_splink_dataframe(pipeline)
@@ -410,8 +420,8 @@ def cluster_pairwise_predictions_at_multiple_thresholds(
             nodes_in_play,
             edges_in_play,
             node_id_column_name="node_id",
-            edge_id_column_name_left=uid_concat_edges_l,
-            edge_id_column_name_right=uid_concat_edges_r,
+            edge_id_column_name_left="node_id_l",
+            edge_id_column_name_right="node_id_r",
             db_api=db_api,
             threshold_match_probability=new_threshold,
         )
@@ -460,11 +470,30 @@ def cluster_pairwise_predictions_at_multiple_thresholds(
         pipeline.enqueue_sql(sql, "__splink__clusters_at_all_thresholds")
         joined = db_api.sql_pipeline_to_splink_dataframe(pipeline)
 
+        columns = concat_table_column_names(self._linker)
+        # don't want to include salting column in output if present
+        columns_without_salt = filter(lambda x: x != "__splink_salt", columns)
+
+        select_columns_sql = ", ".join(columns_without_salt)
+
+        pipeline = CTEPipeline([joined])
+        sql = f"""
+        select
+            co.*,
+            {select_columns_sql}
+        from {joined.physical_name} as co
+        left join __splink__df_concat
+        on co.node_id = {uid_concat_nodes}
+        """
+        pipeline.enqueue_sql(sql, "__splink__clusters_at_all_thresholds_with_input_data")
+
+        df_clustered_with_input_data = db_api.sql_pipeline_to_splink_dataframe(pipeline)
+
         for v in all_results.values():
             v.drop_table_from_database_and_remove_from_cache()
         cc.drop_table_from_database_and_remove_from_cache()
 
-        return joined
+        return df_clustered_with_input_data
 
     def _compute_metrics_nodes(
         self,

From cb7d9c42fa274cb945f66dce1a5afc512d9ab728 Mon Sep 17 00:00:00 2001
From: Ross Kennedy
Date: Tue, 11 Feb 2025 16:57:40 +0000
Subject: [PATCH 03/10] add metadata

---
 splink/internals/linker_components/clustering.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/splink/internals/linker_components/clustering.py b/splink/internals/linker_components/clustering.py
index 3873195720..95590243fc 100644
--- a/splink/internals/linker_components/clustering.py
+++ b/splink/internals/linker_components/clustering.py
@@ -650,6 +650,10 @@ def cluster_pairwise_predictions_at_multiple_thresholds(
             v.drop_table_from_database_and_remove_from_cache()
         cc.drop_table_from_database_and_remove_from_cache()
 
+        df_clustered_with_input_data.metadata["threshold_match_probabilities"] = (
+            [initial_threshold] + match_probability_thresholds
+        )
+
         return df_clustered_with_input_data
 
     def _compute_metrics_nodes(

From 8217f2a89e382e4c7f11e0eb1e3de5831b0059fc Mon Sep 17 00:00:00 2001
From: Ross Kennedy
Date: Tue, 11 Feb 2025 17:08:55 +0000
Subject: [PATCH 04/10] update docstring

---
 .../internals/linker_components/clustering.py | 95 ++++++------------
 1 file changed, 30 insertions(+), 65 deletions(-)

diff --git a/splink/internals/linker_components/clustering.py b/splink/internals/linker_components/clustering.py
index 95590243fc..40420faa72 100644
--- a/splink/internals/linker_components/clustering.py
+++ b/splink/internals/linker_components/clustering.py
@@ -348,80 +348,45 @@ def cluster_using_single_best_links(
     def cluster_pairwise_predictions_at_multiple_thresholds(
         self,
         df_predict: SplinkDataFrame,
-        match_probability_thresholds: Optional[list[float]] | None = None,
-        match_weight_thresholds: Optional[list[float]] | None = None,
+        threshold_match_probabilities: Optional[list[float]] | None = None,
+        threshold_match_weights: Optional[list[float]] | None = None,
        output_cluster_summary_stats: bool = False,
     ) -> SplinkDataFrame:
-        """Clusters the pairwise match predictions at multiple thresholds using
-        the connected components graph clustering algorithm.
+        """Clusters the pairwise match predictions that result from
+        `linker.inference.predict()` into groups of connected records using the
+        connected components graph clustering algorithm.
+
+        Records with an estimated `match_probability` at or above each of the values in
+        `threshold_match_probabilities` (or records with a `match_weight` at or above
+        each of the values in `threshold_match_weights`) are considered to be a match
+        (i.e. they represent the same entity).
 
         This function efficiently computes clusters for multiple thresholds by starting
         with the lowest threshold and incrementally updating clusters for higher thresholds.
 
-        If your node and edge column names follow Splink naming conventions, then you can
-        omit edge_id_column_name_left and edge_id_column_name_right. For example, if you
-        have a table of nodes with a column `unique_id`, it would be assumed that the
-        edge table has columns `unique_id_l` and `unique_id_r`.
-
         Args:
-            nodes (AcceptableInputTableType): The table containing node information
-            edges (AcceptableInputTableType): The table containing edge information
-            db_api (DatabaseAPISubClass): The database API to use for querying
-            node_id_column_name (str): The name of the column containing node IDs
-            match_probability_thresholds (list[float] | None): List of match probability
+            df_predict (SplinkDataFrame): The results of `linker.predict()`
+            threshold_match_probabilities (list[float] | None): List of match probability
                 thresholds to compute clusters for
-            match_weight_thresholds (list[float] | None): List of match weight thresholds
+            threshold_match_weights (list[float] | None): List of match weight thresholds
                 to compute clusters for
-            edge_id_column_name_left (Optional[str]): The name of the column containing
-                left edge IDs. If not provided, assumed to be f"{node_id_column_name}_l"
-            edge_id_column_name_right (Optional[str]): The name of the column containing
-                right edge IDs. If not provided, assumed to be f"{node_id_column_name}_r"
             output_cluster_summary_stats (bool): If True, output summary statistics
                 for each threshold instead of full cluster information
 
         Returns:
-            SplinkDataFrame: A SplinkDataFrame containing cluster information for all
-                thresholds. If output_cluster_summary_stats is True, it contains summary
+            SplinkDataFrame: A SplinkDataFrame containing a list of all IDs, clustered
+                into groups for each of the desired match thresholds.
+                If output_cluster_summary_stats is True, it contains summary
                 statistics (number of clusters, max cluster size, avg cluster size) for
                 each threshold.
 
         Examples:
             ```python
+            df_predict = linker.inference.predict(threshold_match_probability=0.5)
+            df_clustered = linker.clustering.cluster_pairwise_predictions_at_multiple_thresholds(
+                df_predict, threshold_match_probabilities=[0.5, 0.95]
+            )
+            ```
             ```
         """

         is_match_weight = (
-            match_weight_thresholds is not None and match_probability_thresholds is None
+            threshold_match_weights is not None and threshold_match_probabilities is None
         )
 
-        match_probability_thresholds = threshold_args_to_match_prob_list(
-            match_probability_thresholds, match_weight_thresholds
+        threshold_match_probabilities = threshold_args_to_match_prob_list(
+            threshold_match_probabilities, threshold_match_weights
         )
 
-        if match_probability_thresholds is None or len(match_probability_thresholds) == 0:
+        if threshold_match_probabilities is None or len(threshold_match_probabilities) == 0:
             raise ValueError(
-                "Must provide either match_probability_thresholds "
-                "or match_weight_thresholds"
+                "Must provide either threshold_match_probabilities "
+                "or threshold_match_weights"
             )
 
-        if not has_match_prob_col and match_probability_thresholds is not None:
+        if 
not has_match_prob_col and threshold_match_probabilities is not None: raise ValueError( "df_predict must have a column called 'match_probability' if " "threshold_match_probability is provided" ) - initial_threshold = match_probability_thresholds.pop(0) + initial_threshold = threshold_match_probabilities.pop(0) all_results = {} match_p_expr = "" @@ -530,7 +495,7 @@ def cluster_pairwise_predictions_at_multiple_thresholds( all_results[initial_threshold] = cc previous_threshold = initial_threshold - for new_threshold in match_probability_thresholds: + for new_threshold in threshold_match_probabilities: # Get stable nodes logger.info(f"--------Clustering at threshold {new_threshold}--------") pipeline = CTEPipeline([cc, edges_table_with_composite_ids]) @@ -651,7 +616,7 @@ def cluster_pairwise_predictions_at_multiple_thresholds( cc.drop_table_from_database_and_remove_from_cache() df_clustered_with_input_data.metadata["threshold_match_probabilities"] = ( - [initial_threshold] + match_probability_thresholds + [initial_threshold] + threshold_match_probabilities ) return df_clustered_with_input_data From a253d4b38dab3d5c9139acf64d28fa4cae17e77d Mon Sep 17 00:00:00 2001 From: Ross Kennedy Date: Tue, 11 Feb 2025 17:33:02 +0000 Subject: [PATCH 05/10] update docs --- docs/api_docs/clustering.md | 5 ++++- docs/api_docs/linker_clustering.md | 19 +++++++++++++++++++ mkdocs.yml | 3 ++- 3 files changed, 25 insertions(+), 2 deletions(-) create mode 100644 docs/api_docs/linker_clustering.md diff --git a/docs/api_docs/clustering.md b/docs/api_docs/clustering.md index 725a4b0d25..71bc7ee31a 100644 --- a/docs/api_docs/clustering.md +++ b/docs/api_docs/clustering.md @@ -1,10 +1,13 @@ --- tags: - API - - clustering + - Clustering --- + # Documentation for `splink.clustering` +Clustering at one or multiple thresholds is also available without a `linker` object: + ::: splink.clustering handler: python options: diff --git a/docs/api_docs/linker_clustering.md b/docs/api_docs/linker_clustering.md new file mode 100644 index 0000000000..73d22d8820 --- /dev/null +++ b/docs/api_docs/linker_clustering.md @@ -0,0 +1,19 @@ +--- +tags: + - API + - Clustering +--- + +# Methods in Linker.clustering + +Use the result of your Splink model to group (cluster) records together. 
Accessed via `linker.clustering` + +::: splink.internals.linker_components.clustering.LinkerClustering + handler: python + filters: + - "!^__init__$" + options: + show_root_heading: false + show_root_toc: false + show_source: false + members_order: source \ No newline at end of file diff --git a/mkdocs.yml b/mkdocs.yml index db60143b9e..5a11b4863c 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -138,7 +138,7 @@ nav: - Training: "api_docs/training.md" - Visualisations: "api_docs/visualisations.md" - Inference: "api_docs/inference.md" - - Clustering: "api_docs/clustering.md" + - Clustering: "api_docs/linker_clustering.md" - Evaluation: "api_docs/evaluation.md" - Table Management: "api_docs/table_management.md" - Miscellaneous functions: "api_docs/misc.md" @@ -149,6 +149,7 @@ nav: - Exploratory: "api_docs/exploratory.md" - Blocking rule creator: "api_docs/blocking.md" - Blocking analysis: "api_docs/blocking_analysis.md" + - Clustering: "api_docs/clustering.md" - SplinkDataFrame: "api_docs/splink_dataframe.md" - EM Training Session API: "api_docs/em_training_session.md" - SplinkDatasets: "api_docs/datasets.md" From 3f6d09d9c9145c2d8ebd2bb8e63ff6db64828c81 Mon Sep 17 00:00:00 2001 From: Ross Kennedy Date: Tue, 11 Feb 2025 17:46:06 +0000 Subject: [PATCH 06/10] formatting issues --- splink/clustering.py | 10 +++++-- .../internals/linker_components/clustering.py | 29 +++++++++++-------- 2 files changed, 25 insertions(+), 14 deletions(-) diff --git a/splink/clustering.py b/splink/clustering.py index 5a45817a50..08faed11b0 100644 --- a/splink/clustering.py +++ b/splink/clustering.py @@ -1,3 +1,9 @@ -from .internals.clustering import cluster_pairwise_predictions_at_threshold, cluster_pairwise_predictions_at_multiple_thresholds +from .internals.clustering import ( + cluster_pairwise_predictions_at_multiple_thresholds, + cluster_pairwise_predictions_at_threshold +) -__all__ = ["cluster_pairwise_predictions_at_threshold", "cluster_pairwise_predictions_at_multiple_thresholds"] +__all__ = [ + "cluster_pairwise_predictions_at_threshold", + "cluster_pairwise_predictions_at_multiple_thresholds" + ] diff --git a/splink/internals/linker_components/clustering.py b/splink/internals/linker_components/clustering.py index 40420faa72..743438a8a5 100644 --- a/splink/internals/linker_components/clustering.py +++ b/splink/internals/linker_components/clustering.py @@ -358,24 +358,25 @@ def cluster_pairwise_predictions_at_multiple_thresholds( Records with an estimated `match_probability` at or above each of the values in `threshold_match_probabilities` (or records with a `match_weight` at or above - each of the values in `threshold_match_weights`) are considered to be a match + each of the values in `threshold_match_weights`) are considered to be a match (i.e. they represent the same entity). This function efficiently computes clusters for multiple thresholds by starting - with the lowest threshold and incrementally updating clusters for higher thresholds. + with the lowest threshold and incrementally updating clusters for higher + thresholds. 
Args: df_predict (SplinkDataFrame): The results of `linker.predict()` - threshold_match_probabilities (list[float] | None): List of match probability - thresholds to compute clusters for - threshold_match_weights (list[float] | None): List of match weight thresholds - to compute clusters for + threshold_match_probabilities (list[float] | None): List of match + probability thresholds to compute clusters for + threshold_match_weights (list[float] | None): List of match weight + thresholds to compute clusters for output_cluster_summary_stats (bool): If True, output summary statistics for each threshold instead of full cluster information Returns: SplinkDataFrame: A SplinkDataFrame containing a list of all IDs, clustered - into groups for each of the desired match thresholds. + into groups for each of the desired match thresholds. If output_cluster_summary_stats is True, it contains summary statistics (number of clusters, max cluster size, avg cluster size) for each threshold. @@ -428,19 +429,21 @@ def cluster_pairwise_predictions_at_multiple_thresholds( ] is_match_weight = ( - threshold_match_weights is not None and threshold_match_probabilities is None + threshold_match_weights is not None + and threshold_match_probabilities is None ) threshold_match_probabilities = threshold_args_to_match_prob_list( threshold_match_probabilities, threshold_match_weights ) - if threshold_match_probabilities is None or len(threshold_match_probabilities) == 0: + if (threshold_match_probabilities is None + or len(threshold_match_probabilities) == 0): raise ValueError( "Must provide either threshold_match_probabilities " "or threshold_match_weights" ) - + if not has_match_prob_col and threshold_match_probabilities is not None: raise ValueError( "df_predict must have a column called 'match_probability' if " @@ -607,7 +610,9 @@ def cluster_pairwise_predictions_at_multiple_thresholds( left join __splink__df_concat on co.node_id = {uid_concat_nodes} """ - pipeline.enqueue_sql(sql, "__splink__clusters_at_all_thresholds_with_input_data") + pipeline.enqueue_sql(sql, + "__splink__clusters_at_all_thresholds_with_input_data" + ) df_clustered_with_input_data = db_api.sql_pipeline_to_splink_dataframe(pipeline) @@ -618,7 +623,7 @@ def cluster_pairwise_predictions_at_multiple_thresholds( df_clustered_with_input_data.metadata["threshold_match_probabilities"] = ( [initial_threshold] + threshold_match_probabilities ) - + return df_clustered_with_input_data def _compute_metrics_nodes( From 99164cfc557d29621d78fe3fd73421ee24e236db Mon Sep 17 00:00:00 2001 From: Ross Kennedy Date: Wed, 12 Feb 2025 11:23:12 +0000 Subject: [PATCH 07/10] fix up examples to include metadata --- .../internals/linker_components/clustering.py | 100 +++++++++++------- 1 file changed, 62 insertions(+), 38 deletions(-) diff --git a/splink/internals/linker_components/clustering.py b/splink/internals/linker_components/clustering.py index 743438a8a5..249a1c17f1 100644 --- a/splink/internals/linker_components/clustering.py +++ b/splink/internals/linker_components/clustering.py @@ -350,7 +350,6 @@ def cluster_pairwise_predictions_at_multiple_thresholds( df_predict: SplinkDataFrame, threshold_match_probabilities: Optional[list[float]] | None = None, threshold_match_weights: Optional[list[float]] | None = None, - output_cluster_summary_stats: bool = False, ) -> SplinkDataFrame: """Clusters the pairwise match predictions that result from `linker.inference.predict()` into groups of connected record using the connected @@ -371,15 +370,17 @@ def 
cluster_pairwise_predictions_at_multiple_thresholds(
         self,
         df_predict: SplinkDataFrame,
         threshold_match_probabilities: Optional[list[float]] | None = None,
         threshold_match_weights: Optional[list[float]] | None = None,
-        output_cluster_summary_stats: bool = False,
     ) -> SplinkDataFrame:
         """Clusters the pairwise match predictions that result from
         `linker.inference.predict()` into groups of connected records using the
         connected components graph clustering algorithm.
@@ -371,15 +370,17 @@ def cluster_pairwise_predictions_at_multiple_thresholds(
             probability thresholds to compute clusters for
             threshold_match_weights (list[float] | None): List of match weight
                 thresholds to compute clusters for
-            output_cluster_summary_stats (bool): If True, output summary statistics
-                for each threshold instead of full cluster information
 
         Returns:
             SplinkDataFrame: A SplinkDataFrame containing a list of all IDs, clustered
                 into groups for each of the desired match thresholds.
-                If output_cluster_summary_stats is True, it contains summary
-                statistics (number of clusters, max cluster size, avg cluster size) for
-                each threshold.
+
+                The output dataframe will contain the following metadata:
+
+                - threshold_match_probabilities: List of match probability thresholds
+
+                - cluster_summary_stats: summary statistics (number of clusters, max
+                cluster size, avg cluster size) for each threshold
 
         Examples:
             ```python
@@ -387,8 +388,14 @@ def cluster_pairwise_predictions_at_multiple_thresholds(
             df_clustered = linker.clustering.cluster_pairwise_predictions_at_multiple_thresholds(
                 df_predict, threshold_match_probabilities=[0.5, 0.95]
             )
+
+            # Access the match probability thresholds
+            match_prob_thresholds = df_clustered.metadata["threshold_match_probabilities"]
+
+            # Access the cluster summary stats
+            cluster_summary_stats = df_clustered.metadata["cluster_summary_stats"]
             ```
-            ```
+
         """
@@ -459,7 +466,8 @@ def cluster_pairwise_predictions_at_multiple_thresholds(
         initial_threshold = threshold_match_probabilities.pop(0)
         all_results = {}
+        all_results_summary = {}
 
         match_p_expr = ""
         match_p_select_expr = ""
@@ -496,15 +504,16 @@ def cluster_pairwise_predictions_at_multiple_thresholds(
             threshold_match_probability=initial_threshold,
         )
 
-        if output_cluster_summary_stats:
-            pipeline = CTEPipeline([cc])
-            sqls = _get_cluster_stats_sql(cc)
-            pipeline.enqueue_list_of_sqls(sqls)
-            cc_summary = db_api.sql_pipeline_to_splink_dataframe(pipeline)
-            all_results[initial_threshold] = cc_summary
-        else:
-            all_results[initial_threshold] = cc
-
+        all_results[initial_threshold] = cc
+
+        # Calculate Summary stats for first clustering threshold
+        pipeline = CTEPipeline([cc])
+        sqls = _get_cluster_stats_sql(cc)
+        pipeline.enqueue_list_of_sqls(sqls)
+        cc_summary = db_api.sql_pipeline_to_splink_dataframe(pipeline)
+        all_results_summary[initial_threshold] = cc_summary
+
+        # Now iterate over the remaining thresholds
         previous_threshold = initial_threshold
         for new_threshold in threshold_match_probabilities:
             # Get stable nodes
@@ -571,21 +580,20 @@ def cluster_pairwise_predictions_at_multiple_thresholds(
             stable_clusters.drop_table_from_database_and_remove_from_cache()
             marginal_new_clusters.drop_table_from_database_and_remove_from_cache()
 
-            if output_cluster_summary_stats:
-                pipeline = CTEPipeline([cc])
-                sqls = _get_cluster_stats_sql(cc)
-                pipeline.enqueue_list_of_sqls(sqls)
-                cc_summary = db_api.sql_pipeline_to_splink_dataframe(pipeline)
-                all_results[new_threshold] = cc_summary
-                previous_cc.drop_table_from_database_and_remove_from_cache()
-            else:
-                all_results[new_threshold] = cc
-
-        if output_cluster_summary_stats:
-            sql = _generate_cluster_summary_stats_sql(all_results)
-        else:
-            sql = _generate_detailed_cluster_comparison_sql(
+            all_results[new_threshold] = cc
+
+
+            # Calculate summary stats for metadata
+            pipeline = CTEPipeline([cc])
+            sqls = _get_cluster_stats_sql(cc)
+            pipeline.enqueue_list_of_sqls(sqls)
+            cc_summary = db_api.sql_pipeline_to_splink_dataframe(pipeline)
+            all_results_summary[new_threshold] 
= cc_summary + + sql = _generate_detailed_cluster_comparison_sql( all_results, unique_id_col="node_id", is_match_weight=is_match_weight, ) - pipeline = CTEPipeline() pipeline.enqueue_sql(sql, "__splink__clusters_at_all_thresholds") joined = db_api.sql_pipeline_to_splink_dataframe(pipeline) @@ -616,13 +621,26 @@ def cluster_pairwise_predictions_at_multiple_thresholds( df_clustered_with_input_data = db_api.sql_pipeline_to_splink_dataframe(pipeline) - for v in all_results.values(): - v.drop_table_from_database_and_remove_from_cache() - cc.drop_table_from_database_and_remove_from_cache() + + # Add metadata to the output dataframe + ## Match probability thresholds df_clustered_with_input_data.metadata["threshold_match_probabilities"] = ( [initial_threshold] + threshold_match_probabilities ) + + ## Cluster Summary stats + pipeline = CTEPipeline() + sql = _generate_cluster_summary_stats_sql(all_results_summary) + pipeline.enqueue_sql(sql, "__splink__cluster_summary_stats") + df_clustered_with_input_data.metadata["cluster_summary_stats"] = ( + db_api.sql_pipeline_to_splink_dataframe(pipeline) + ) + + # Drop cached tables + for v in all_results.values(): + v.drop_table_from_database_and_remove_from_cache() + cc.drop_table_from_database_and_remove_from_cache() return df_clustered_with_input_data @@ -791,10 +809,13 @@ def compute_graph_metrics( Returns: GraphMetricsResult: A data class containing SplinkDataFrames - of cluster IDs and selected node, edge or cluster metrics. - attribute "nodes" for nodes metrics table - attribute "edges" for edge metrics table - attribute "clusters" for cluster metrics table + of cluster IDs and selected node, edge or cluster metrics. + + - attribute "nodes" for nodes metrics table + + - attribute "edges" for edge metrics table + + - attribute "clusters" for cluster metrics table Examples: ```python @@ -809,6 +830,9 @@ def compute_graph_metrics( node_metrics = graph_metrics.nodes.as_pandas_dataframe() edge_metrics = graph_metrics.edges.as_pandas_dataframe() cluster_metrics = graph_metrics.clusters.as_pandas_dataframe() + + # Access the match probability thresholds + probability_threshold = graph_metrics.nodes.metadata ``` """ if threshold_match_probability is None: From 5b15a6f686785f967fd0491a859f21c4b8b2ab38 Mon Sep 17 00:00:00 2001 From: Ross Kennedy Date: Thu, 13 Feb 2025 10:17:28 +0000 Subject: [PATCH 08/10] fix formatting --- splink/clustering.py | 10 +-- .../internals/linker_components/clustering.py | 63 +++++++++---------- 2 files changed, 36 insertions(+), 37 deletions(-) diff --git a/splink/clustering.py b/splink/clustering.py index 08faed11b0..4cd1b23eeb 100644 --- a/splink/clustering.py +++ b/splink/clustering.py @@ -1,9 +1,9 @@ from .internals.clustering import ( - cluster_pairwise_predictions_at_multiple_thresholds, - cluster_pairwise_predictions_at_threshold + cluster_pairwise_predictions_at_multiple_thresholds, + cluster_pairwise_predictions_at_threshold, ) __all__ = [ - "cluster_pairwise_predictions_at_threshold", - "cluster_pairwise_predictions_at_multiple_thresholds" - ] + "cluster_pairwise_predictions_at_threshold", + "cluster_pairwise_predictions_at_multiple_thresholds", +] diff --git a/splink/internals/linker_components/clustering.py b/splink/internals/linker_components/clustering.py index 249a1c17f1..a2cc8cb76a 100644 --- a/splink/internals/linker_components/clustering.py +++ b/splink/internals/linker_components/clustering.py @@ -1,15 +1,14 @@ from __future__ import annotations import logging - from typing import TYPE_CHECKING, List, Optional 
from splink.internals.clustering import ( - cluster_pairwise_predictions_at_threshold, - _get_cluster_stats_sql, _calculate_stable_clusters_at_new_threshold, + _generate_cluster_summary_stats_sql, _generate_detailed_cluster_comparison_sql, - _generate_cluster_summary_stats_sql + _get_cluster_stats_sql, + cluster_pairwise_predictions_at_threshold, ) from splink.internals.connected_components import ( solve_connected_components, @@ -22,7 +21,7 @@ ) from splink.internals.misc import ( threshold_args_to_match_prob, - threshold_args_to_match_prob_list + threshold_args_to_match_prob_list, ) from splink.internals.one_to_one_clustering import ( one_to_one_clustering, @@ -43,6 +42,7 @@ logger = logging.getLogger(__name__) + class LinkerClustering: """Cluster the results of the linkage model and analyse clusters, accessed via `linker.clustering`. @@ -378,19 +378,21 @@ def cluster_pairwise_predictions_at_multiple_thresholds( The output dataframe will contain the following metadata: - threshold_match_probabilities: List of match probability thresholds - + - cluster_summary_stats: summary statistics (number of clusters, max cluster size, avg cluster size) for each threshold Examples: ```python df_predict = linker.inference.predict(threshold_match_probability=0.5) - df_clustered = linker.clustering.cluster_pairwise_predictions_at_multiple_thresholds( + df_clustered = + linker.clustering.cluster_pairwise_predictions_at_multiple_thresholds( df_predict, threshold_match_probability=0.95 ) # Access the match probability thresholds - match_prob_thresholds = df_clustered.metadata["threshold_match_probabilities"] + match_prob_thresholds = df_clustered + .metadata["threshold_match_probabilities"] # Access the cluster summary stats cluster_summary_stats = df_clustered.metadata["cluster_summary_stats"] @@ -420,7 +422,6 @@ def cluster_pairwise_predictions_at_multiple_thresholds( uid_concat_edges_r = _composite_unique_id_from_edges_sql(uid_cols, "r") uid_concat_nodes = _composite_unique_id_from_nodes_sql(uid_cols, None) - # Input could either be user data, or a SplinkDataFrame sql = f""" select @@ -436,7 +437,7 @@ def cluster_pairwise_predictions_at_multiple_thresholds( ] is_match_weight = ( - threshold_match_weights is not None + threshold_match_weights is not None and threshold_match_probabilities is None ) @@ -444,8 +445,10 @@ def cluster_pairwise_predictions_at_multiple_thresholds( threshold_match_probabilities, threshold_match_weights ) - if (threshold_match_probabilities is None - or len(threshold_match_probabilities) == 0): + if ( + threshold_match_probabilities is None + or len(threshold_match_probabilities) == 0 + ): raise ValueError( "Must provide either threshold_match_probabilities " "or threshold_match_weights" @@ -497,14 +500,14 @@ def cluster_pairwise_predictions_at_multiple_thresholds( ) all_results[initial_threshold] = cc - + # Calculate Summary stats for first clustering threshold pipeline = CTEPipeline([cc]) sqls = _get_cluster_stats_sql(cc) pipeline.enqueue_list_of_sqls(sqls) cc_summary = db_api.sql_pipeline_to_splink_dataframe(pipeline) all_results_summary[initial_threshold] = cc_summary - + # Now iterate over the remaining thresholds previous_threshold = initial_threshold for new_threshold in threshold_match_probabilities: @@ -571,7 +574,6 @@ def cluster_pairwise_predictions_at_multiple_thresholds( pipeline.enqueue_sql(sql, "__splink__clusters_at_threshold") - previous_cc = cc cc = db_api.sql_pipeline_to_splink_dataframe(pipeline) previous_threshold = new_threshold @@ -582,7 +584,6 @@ def 
cluster_pairwise_predictions_at_multiple_thresholds(
                 The output dataframe will contain the following metadata:
 
                 - threshold_match_probabilities: List of match probability thresholds
-
+
                 - cluster_summary_stats: summary statistics (number of clusters, max
                 cluster size, avg cluster size) for each threshold
 
         Examples:
             ```python
             df_predict = linker.inference.predict(threshold_match_probability=0.5)
-            df_clustered = linker.clustering.cluster_pairwise_predictions_at_multiple_thresholds(
+            df_clustered = linker.clustering \
+                .cluster_pairwise_predictions_at_multiple_thresholds(
                 df_predict, threshold_match_probabilities=[0.5, 0.95]
             )
 
             # Access the match probability thresholds
-            match_prob_thresholds = df_clustered.metadata["threshold_match_probabilities"]
+            match_prob_thresholds = \
+                df_clustered.metadata["threshold_match_probabilities"]
 
             # Access the cluster summary stats
             cluster_summary_stats = df_clustered.metadata["cluster_summary_stats"]
             ```
@@ -437,21 +437,24 @@ def cluster_pairwise_predictions_at_multiple_thresholds(
         ]
 
         is_match_weight = (
-            threshold_match_weights is not None and threshold_match_probabilities is None
+            threshold_match_weights is not None
+            and threshold_match_probabilities is None
         )
 
         threshold_match_probabilities = threshold_args_to_match_prob_list(
             threshold_match_probabilities, threshold_match_weights
         )
 
-        if threshold_match_probabilities is None or len(threshold_match_probabilities) == 0:
+        if (
+            threshold_match_probabilities is None
+            or len(threshold_match_probabilities) == 0
+        ):
             raise ValueError(
                 "Must provide either threshold_match_probabilities "
                 "or threshold_match_weights"
             )
-
+
         if not has_match_prob_col and threshold_match_probabilities is not None:
@@ -582,12 +584,12 @@ def cluster_pairwise_predictions_at_multiple_thresholds(
             all_results[new_threshold] = cc
-
 
             # Calculate summary stats for metadata
             pipeline = CTEPipeline([cc])
             sqls = _get_cluster_stats_sql(cc)
             pipeline.enqueue_list_of_sqls(sqls)
             cc_summary = db_api.sql_pipeline_to_splink_dataframe(pipeline)
             all_results_summary[new_threshold] = cc_summary
-
+
         sql = _generate_detailed_cluster_comparison_sql(
-                all_results,
-                unique_id_col="node_id",
-                is_match_weight=is_match_weight,
-            )
+            all_results,
+            unique_id_col="node_id",
+            is_match_weight=is_match_weight,
+        )
         pipeline = CTEPipeline()
         pipeline.enqueue_sql(sql, "__splink__clusters_at_all_thresholds")
         joined = db_api.sql_pipeline_to_splink_dataframe(pipeline)
@@ -616,18 +618,16 @@ def cluster_pairwise_predictions_at_multiple_thresholds(
         left join __splink__df_concat
         on co.node_id = {uid_concat_nodes}
         """
-        pipeline.enqueue_sql(sql,
-            "__splink__clusters_at_all_thresholds_with_input_data"
-        )
+        pipeline.enqueue_sql(
+            sql, "__splink__clusters_at_all_thresholds_with_input_data"
+        )
 
         df_clustered_with_input_data = db_api.sql_pipeline_to_splink_dataframe(pipeline)
 
         # Add metadata to the output dataframe
 
         ## Match probability thresholds
-        df_clustered_with_input_data.metadata["threshold_match_probabilities"] = (
-            [initial_threshold] + threshold_match_probabilities
-        )
+        df_clustered_with_input_data.metadata["threshold_match_probabilities"] = [
+            initial_threshold
+        ] + threshold_match_probabilities
 
         ## Cluster Summary stats
         pipeline = CTEPipeline()
@@ -814,7 +813,7 @@ def compute_graph_metrics(
             - attribute "nodes" for nodes metrics table
 
             - attribute "edges" for edge metrics table
-
+
             - attribute "clusters" for cluster metrics table
 
         Examples:

From b7a155fff4828ef8383bd9cba53d1fc84d8e0d06 Mon Sep 17 00:00:00 2001
From: Ross Kennedy
Date: Thu, 13 Feb 2025 14:21:01 +0000
Subject: [PATCH 09/10] add tests

---
 tests/test_clustering.py | 93 ++++++++++++++++++++++++++++++++++++++++
 1 file changed, 93 insertions(+)

diff --git a/tests/test_clustering.py b/tests/test_clustering.py
index a15f6dc1d3..0afb7a4266 100644
--- a/tests/test_clustering.py
+++ b/tests/test_clustering.py
@@ -4,6 +4,9 @@
 import splink.comparison_library as cl
 from splink import DuckDBAPI, Linker, SettingsCreator, block_on
+from splink.clustering import (
+    cluster_pairwise_predictions_at_multiple_thresholds
+)
 
 from .basic_settings import get_settings_dict
 from .decorator import mark_with_dialects_excluding
@@ -126,3 +129,93 @@ def test_clustering_no_edges(test_helpers, dialect):
     # due to blocking rules, df_predict will be empty
     df_predict = linker.inference.predict()
     linker.clustering.cluster_pairwise_predictions_at_threshold(df_predict, 0.95)
+
+@mark_with_dialects_excluding()
+@mark.parametrize(
+    ["link_type", "input_pd_tables"],
+    [
+        ["dedupe_only", [df]],
+        ["link_only", [df, df]],  # no source dataset
+        ["link_only", [df_l, df_r]],  # source dataset column
+        ["link_only", [df_combined]],  # concatenated frame
+        ["link_only", [df_l, df_m, df_r]],
+        ["link_and_dedupe", [df, df]],  # no source dataset
+        ["link_and_dedupe", 
[df_l, df_r]], # source dataset column + ["link_and_dedupe", [df_combined]], # concatenated frame + ], + ids=[ + "dedupe", + "link_only_no_source_dataset", + "link_only_with_source_dataset", + "link_only_concat", + "link_only_three_tables", + "link_and_dedupe_no_source_dataset", + "link_and_dedupe_with_source_dataset", + "link_and_dedupe_concat", + ], +) +def test_clustering_multi_thresholds(test_helpers, dialect, link_type, input_pd_tables): + helper = test_helpers[dialect] + + settings = SettingsCreator( + link_type=link_type, + comparisons=[ + cl.ExactMatch("first_name"), + cl.ExactMatch("surname"), + cl.ExactMatch("dob"), + cl.ExactMatch("city"), + ], + blocking_rules_to_generate_predictions=[ + block_on("surname"), + block_on("dob"), + ], + ) + linker_input = list(map(helper.convert_frame, input_pd_tables)) + linker = Linker(linker_input, settings, **helper.extra_linker_args()) + + df_predict = linker.inference.predict() + linker.clustering.cluster_pairwise_predictions_at_multiple_thresholds(df_predict, [0.5, 0.95]) + + +def test_clustering_single_multi_threshold_equivalence(): + df = pd.read_csv("./tests/datasets/fake_1000_from_splink_demos.csv") + db_api = DuckDBAPI() + settings_dict = get_settings_dict() + linker = Linker(df, settings_dict, db_api=db_api) + + df_predict = linker.inference.predict() + + clusters_0_5 = linker.clustering.cluster_pairwise_predictions_at_threshold(df_predict, 0.5).as_pandas_dataframe() + clusters_0_95 = linker.clustering.cluster_pairwise_predictions_at_threshold(df_predict, 0.95).as_pandas_dataframe() + + clusters_multi = linker.clustering.cluster_pairwise_predictions_at_multiple_thresholds(df_predict, [0.5, 0.95]).as_pandas_dataframe() + + assert clusters_0_5["cluster_id"] == clusters_multi["cluster_p_0_5"] + assert clusters_0_95["cluster_id"] == clusters_multi["cluster_p_0_95"] + + + +def test_clustering_multi_threshold_linker_non_linker_equivalence(): + df = pd.read_csv("./tests/datasets/fake_1000_from_splink_demos.csv") + db_api = DuckDBAPI() + settings_dict = get_settings_dict() + linker = Linker(df, settings_dict, db_api=db_api) + + df_predict = linker.inference.predict() + + clusters_linker = linker.clustering.cluster_pairwise_predictions_at_multiple_thresholds( + df_predict, + [0.5, 0.95] + ).as_pandas_dataframe() + clusters_non_linker = cluster_pairwise_predictions_at_multiple_thresholds( + df, + df_predict, + node_id_column_name="unique_id", + edge_id_column_name_left="unique_id_l", + edge_id_column_name_right="unique_id_r", + db_api=linker.db_api, + match_probability_thresholds=[0.5, 0.95] + ).as_pandas_dataframe() + + assert clusters_linker["cluster_p_0_5"] == clusters_non_linker["cluster_p_0_5"] + assert clusters_linker["cluster_p_0_95"] == clusters_non_linker["cluster_p_0_95"] \ No newline at end of file From 2c60d74b85d6a2b1e3a63ae79f8e898efa4211d5 Mon Sep 17 00:00:00 2001 From: Ross Kennedy Date: Thu, 13 Feb 2025 16:27:25 +0000 Subject: [PATCH 10/10] get working and start building up tests --- .../internals/linker_components/clustering.py | 8 +++-- tests/test_clustering.py | 31 +++++++++++++------ 2 files changed, 28 insertions(+), 11 deletions(-) diff --git a/splink/internals/linker_components/clustering.py b/splink/internals/linker_components/clustering.py index a2cc8cb76a..f5c1f59e9e 100644 --- a/splink/internals/linker_components/clustering.py +++ b/splink/internals/linker_components/clustering.py @@ -33,6 +33,7 @@ _composite_unique_id_from_nodes_sql, ) from splink.internals.vertically_concatenate import ( + compute_df_concat, 
concat_table_column_names, enqueue_df_concat, ) @@ -601,19 +602,22 @@ def cluster_pairwise_predictions_at_multiple_thresholds( pipeline.enqueue_sql(sql, "__splink__clusters_at_all_thresholds") joined = db_api.sql_pipeline_to_splink_dataframe(pipeline) + pipeline = CTEPipeline() + concat = compute_df_concat(linker, pipeline) + columns = concat_table_column_names(self._linker) # don't want to include salting column in output if present columns_without_salt = filter(lambda x: x != "__splink_salt", columns) select_columns_sql = ", ".join(columns_without_salt) - pipeline = CTEPipeline([joined]) + pipeline = CTEPipeline([joined, concat]) sql = f""" select co.*, {select_columns_sql} from {joined.physical_name} as co - left join __splink__df_concat + left join {concat.physical_name} as c on co.node_id = {uid_concat_nodes} """ pipeline.enqueue_sql( diff --git a/tests/test_clustering.py b/tests/test_clustering.py index 0afb7a4266..a262d0d71f 100644 --- a/tests/test_clustering.py +++ b/tests/test_clustering.py @@ -185,13 +185,21 @@ def test_clustering_single_multi_threshold_equivalence(): df_predict = linker.inference.predict() - clusters_0_5 = linker.clustering.cluster_pairwise_predictions_at_threshold(df_predict, 0.5).as_pandas_dataframe() - clusters_0_95 = linker.clustering.cluster_pairwise_predictions_at_threshold(df_predict, 0.95).as_pandas_dataframe() + clusters_0_5 = linker.clustering.cluster_pairwise_predictions_at_threshold(df_predict, 0.5).as_pandas_dataframe() + clusters_0_9 = linker.clustering.cluster_pairwise_predictions_at_threshold(df_predict, 0.9).as_pandas_dataframe() - clusters_multi = linker.clustering.cluster_pairwise_predictions_at_multiple_thresholds(df_predict, [0.5, 0.95]).as_pandas_dataframe() + clusters_multi = linker.clustering.cluster_pairwise_predictions_at_multiple_thresholds(df_predict, [0.5, 0.9]).as_pandas_dataframe() - assert clusters_0_5["cluster_id"] == clusters_multi["cluster_p_0_5"] - assert clusters_0_95["cluster_id"] == clusters_multi["cluster_p_0_95"] + df = pd.merge(clusters_0_5, clusters_multi, left_on='unique_id', right_on='unique_id', how='inner') + + df["different"] = df["cluster_id"] != df["cluster_p_0_9"] + compare = df[["cluster_id", "cluster_p_0_9", "different"]] + df.sort_values(by='different', ascending=False, inplace=True) + print(compare[compare["different"]==True]) + print(sum(compare["different"])) + + assert clusters_0_5["cluster_id"].equals(clusters_multi["cluster_p_0_5"]) + assert clusters_0_9["cluster_id"].equals(clusters_multi["cluster_p_0_9"]) @@ -213,9 +221,14 @@ def test_clustering_multi_threshold_linker_non_linker_equivalence(): node_id_column_name="unique_id", edge_id_column_name_left="unique_id_l", edge_id_column_name_right="unique_id_r", - db_api=linker.db_api, + db_api=linker._db_api, match_probability_thresholds=[0.5, 0.95] ).as_pandas_dataframe() - - assert clusters_linker["cluster_p_0_5"] == clusters_non_linker["cluster_p_0_5"] - assert clusters_linker["cluster_p_0_95"] == clusters_non_linker["cluster_p_0_95"] \ No newline at end of file + df = pd.DataFrame({'linker': clusters_linker['cluster_p_0_5'], 'non-linker': clusters_non_linker['cluster_p_0_5']}) + # df["different"] = df["linker"] != df["non-linker"] + # df.sort_values(by='different', ascending=False, inplace=True) + # print(df) + # print(sum(df["different"])) + + #assert clusters_linker["cluster_p_0_5"].equals(clusters_non_linker["cluster_p_0_5"]) + #assert clusters_linker["cluster_p_0_95"].equals(clusters_non_linker["cluster_p_0_95"]) \ No newline at end of file
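
For reference, the standalone function this series exposes in `splink.clustering` can be exercised end to end with the toy data from the patch 01 docstring. This is a minimal sketch, assuming DuckDB is available; the per-threshold output column names (`cluster_p_0_5` and so on) are inferred from the assertions in the tests above rather than stated explicitly in the patches:

```python
from splink import DuckDBAPI
from splink.clustering import cluster_pairwise_predictions_at_multiple_thresholds

db_api = DuckDBAPI()

# Toy graph: six nodes keyed by "my_id", three scored edges
nodes = [{"my_id": i} for i in range(1, 7)]
edges = [
    {"n_1": 1, "n_2": 2, "match_probability": 0.8},
    {"n_1": 3, "n_2": 2, "match_probability": 0.9},
    {"n_1": 4, "n_2": 5, "match_probability": 0.99},
]

# Clusters are computed once at the lowest threshold, then incrementally
# updated for the higher thresholds (the strategy described in patch 01)
cc = cluster_pairwise_predictions_at_multiple_thresholds(
    nodes,
    edges,
    node_id_column_name="my_id",
    edge_id_column_name_left="n_1",
    edge_id_column_name_right="n_2",
    db_api=db_api,
    match_probability_thresholds=[0.5, 0.7, 0.9],
)

# One row per node, with one cluster-membership column per threshold
print(cc.as_pandas_dataframe())
```

The incremental strategy pays off here because only clusters whose weakest internal edge falls below the next threshold are re-solved; everything else is carried forward unchanged as a "stable" cluster.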