diff --git a/docs/api_docs/clustering.md b/docs/api_docs/clustering.md
index 725a4b0d25..71bc7ee31a 100644
--- a/docs/api_docs/clustering.md
+++ b/docs/api_docs/clustering.md
@@ -1,10 +1,13 @@
 ---
 tags:
   - API
-  - clustering
+  - Clustering
 ---
+
 # Documentation for `splink.clustering`
 
+Clustering at one or multiple thresholds is also available without a `linker` object:
+
 ::: splink.clustering
     handler: python
     options:
diff --git a/docs/api_docs/linker_clustering.md b/docs/api_docs/linker_clustering.md
new file mode 100644
index 0000000000..73d22d8820
--- /dev/null
+++ b/docs/api_docs/linker_clustering.md
@@ -0,0 +1,19 @@
+---
+tags:
+  - API
+  - Clustering
+---
+
+# Methods in Linker.clustering
+
+Use the result of your Splink model to group (cluster) records together. Accessed via `linker.clustering`.
+
+::: splink.internals.linker_components.clustering.LinkerClustering
+    handler: python
+    filters:
+      - "!^__init__$"
+    options:
+        show_root_heading: false
+        show_root_toc: false
+        show_source: false
+        members_order: source
\ No newline at end of file
diff --git a/mkdocs.yml b/mkdocs.yml
index d61b06961b..60d040238a 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -140,7 +140,7 @@ nav:
       - Training: "api_docs/training.md"
       - Visualisations: "api_docs/visualisations.md"
      - Inference: "api_docs/inference.md"
-      - Clustering: "api_docs/clustering.md"
+      - Clustering: "api_docs/linker_clustering.md"
       - Evaluation: "api_docs/evaluation.md"
       - Table Management: "api_docs/table_management.md"
       - Miscellaneous functions: "api_docs/misc.md"
@@ -151,6 +151,7 @@
       - Exploratory: "api_docs/exploratory.md"
       - Blocking rule creator: "api_docs/blocking.md"
       - Blocking analysis: "api_docs/blocking_analysis.md"
+      - Clustering: "api_docs/clustering.md"
       - SplinkDataFrame: "api_docs/splink_dataframe.md"
       - EM Training Session API: "api_docs/em_training_session.md"
       - SplinkDatasets: "api_docs/datasets.md"
diff --git a/splink/clustering.py b/splink/clustering.py
index ea22d068d2..4cd1b23eeb 100644
--- a/splink/clustering.py
+++ b/splink/clustering.py
@@ -1,3 +1,9 @@
-from .internals.clustering import cluster_pairwise_predictions_at_threshold
+from .internals.clustering import (
+    cluster_pairwise_predictions_at_multiple_thresholds,
+    cluster_pairwise_predictions_at_threshold,
+)
 
-__all__ = ["cluster_pairwise_predictions_at_threshold"]
+__all__ = [
+    "cluster_pairwise_predictions_at_threshold",
+    "cluster_pairwise_predictions_at_multiple_thresholds",
+]
diff --git a/splink/internals/clustering.py b/splink/internals/clustering.py
index 036ed40e25..dcf343c83b 100644
--- a/splink/internals/clustering.py
+++ b/splink/internals/clustering.py
@@ -474,7 +474,7 @@ def cluster_pairwise_predictions_at_multiple_thresholds(
         edge_id_column_name_left,
         edge_id_column_name_right,
     )
-    
+    logger.info(f"--------Clustering at threshold {initial_threshold}--------")
     # First cluster at the lowest threshold
     cc = cluster_pairwise_predictions_at_threshold(
         nodes=nodes_sdf,
diff --git a/splink/internals/linker_components/clustering.py b/splink/internals/linker_components/clustering.py
index 55ee1a0ce1..f5fe59740c 100644
--- a/splink/internals/linker_components/clustering.py
+++ b/splink/internals/linker_components/clustering.py
@@ -1,7 +1,15 @@
 from __future__ import annotations
 
+import logging
 from typing import TYPE_CHECKING, List, Optional
 
+from splink.internals.clustering import (
+    _calculate_stable_clusters_at_new_threshold,
+    _generate_cluster_summary_stats_sql,
+    _generate_detailed_cluster_comparison_sql,
+    _get_cluster_stats_sql,
+    cluster_pairwise_predictions_at_threshold,
+)
 from splink.internals.connected_components import (
     solve_connected_components,
 )
@@ -13,6 +21,7 @@
 )
 from splink.internals.misc import (
     threshold_args_to_match_prob,
+    threshold_args_to_match_prob_list,
 )
 from splink.internals.one_to_one_clustering import (
     one_to_one_clustering,
@@ -24,6 +33,7 @@
     _composite_unique_id_from_nodes_sql,
 )
 from splink.internals.vertically_concatenate import (
+    compute_df_concat,
     concat_table_column_names,
     enqueue_df_concat,
 )
@@ -31,6 +41,8 @@
 if TYPE_CHECKING:
     from splink.internals.linker import Linker
 
+logger = logging.getLogger(__name__)
+
 
 class LinkerClustering:
     """Cluster the results of the linkage model and analyse clusters, accessed via
@@ -334,6 +346,307 @@ def cluster_using_single_best_links(
 
         return df_clustered_with_input_data
 
+    def cluster_pairwise_predictions_at_multiple_thresholds(
+        self,
+        df_predict: SplinkDataFrame,
+        threshold_match_probabilities: Optional[List[float]] = None,
+        threshold_match_weights: Optional[List[float]] = None,
+    ) -> SplinkDataFrame:
+        """Clusters the pairwise match predictions that result from
+        `linker.inference.predict()` into groups of connected records using the
+        connected components graph clustering algorithm.
+
+        Records with an estimated `match_probability` at or above each of the values in
+        `threshold_match_probabilities` (or records with a `match_weight` at or above
+        each of the values in `threshold_match_weights`) are considered to be a match
+        (i.e. they represent the same entity).
+
+        This function efficiently computes clusters for multiple thresholds by starting
+        with the lowest threshold and incrementally updating clusters for higher
+        thresholds.
+
+        Args:
+            df_predict (SplinkDataFrame): The results of `linker.inference.predict()`
+            threshold_match_probabilities (list[float] | None): List of match
+                probability thresholds to compute clusters for
+            threshold_match_weights (list[float] | None): List of match weight
+                thresholds to compute clusters for
+
+        Returns:
+            SplinkDataFrame: A SplinkDataFrame containing a list of all IDs, clustered
+                into groups for each of the desired match thresholds.
+
+        The output dataframe will contain the following metadata:
+
+        - threshold_match_probabilities: List of match probability thresholds
+
+        - cluster_summary_stats: summary statistics (number of clusters, max
+            cluster size, avg cluster size) for each threshold
+
+        Examples:
+            ```python
+            df_predict = linker.inference.predict(threshold_match_probability=0.5)
+            df_clustered = (
+                linker.clustering.cluster_pairwise_predictions_at_multiple_thresholds(
+                    df_predict, threshold_match_probabilities=[0.5, 0.95]
+                )
+            )
+
+            # Access the match probability thresholds
+            match_prob_thresholds = df_clustered.metadata[
+                "threshold_match_probabilities"
+            ]
+
+            # Access the cluster summary stats
+            cluster_summary_stats = df_clustered.metadata["cluster_summary_stats"]
+            ```
+
+        """
+
+        # Strategy to cluster at multiple thresholds:
+        # 1. Cluster at the lowest threshold
+        # 2. At the next threshold, note that some clusters do not need to be
+        #    recomputed: specifically, those where the minimum match probability
+        #    within the cluster is greater than this next threshold.
+        # 3. Compute the remaining clusters, which _are_ affected by the next
+        #    threshold.
+        # 4. Repeat for the remaining thresholds.
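+        #
+        # Illustrative example (editor's sketch, not from the codebase): with
+        # thresholds [0.5, 0.9] and edges a-b (p=0.95), b-c (p=0.6) and
+        # d-e (p=0.92), clustering at 0.5 gives {a, b, c} and {d, e}.
+        # Moving to 0.9, {d, e} is stable (its lowest internal edge, 0.92,
+        # is >= 0.9) and carries over unchanged, while {a, b, c} contains an
+        # edge below 0.9 and is re-clustered, yielding {a, b} and {c}.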
+
+        # Need to get nodes and edges in a format suitable to pass to
+        # cluster_pairwise_predictions_at_threshold
+        linker = self._linker
+        db_api = linker._db_api
+
+        pipeline = CTEPipeline()
+
+        enqueue_df_concat(linker, pipeline)
+
+        uid_cols = linker._settings_obj.column_info_settings.unique_id_input_columns
+        uid_concat_edges_l = _composite_unique_id_from_edges_sql(uid_cols, "l")
+        uid_concat_edges_r = _composite_unique_id_from_edges_sql(uid_cols, "r")
+        uid_concat_nodes = _composite_unique_id_from_nodes_sql(uid_cols, None)
+
+        # Input could either be user data, or a SplinkDataFrame
+        sql = f"""
+        select
+        {uid_concat_nodes} as node_id
+        from __splink__df_concat
+        """
+        pipeline.enqueue_sql(sql, "__splink__df_nodes_with_composite_ids")
+
+        nodes_with_composite_ids = db_api.sql_pipeline_to_splink_dataframe(pipeline)
+
+        has_match_prob_col = "match_probability" in [
+            c.unquote().name for c in df_predict.columns
+        ]
+
+        is_match_weight = (
+            threshold_match_weights is not None
+            and threshold_match_probabilities is None
+        )
+
+        threshold_match_probabilities = threshold_args_to_match_prob_list(
+            threshold_match_probabilities, threshold_match_weights
+        )
+
+        if (
+            threshold_match_probabilities is None
+            or len(threshold_match_probabilities) == 0
+        ):
+            raise ValueError(
+                "Must provide either threshold_match_probabilities "
+                "or threshold_match_weights"
+            )
+
+        if not has_match_prob_col and threshold_match_probabilities is not None:
+            raise ValueError(
+                "df_predict must have a column called 'match_probability' if "
+                "threshold_match_probabilities is provided"
+            )
+
+        initial_threshold = threshold_match_probabilities.pop(0)
+        all_results = {}
+        all_results_summary = {}
+
+        match_p_expr = ""
+        match_p_select_expr = ""
+        if initial_threshold is not None:
+            match_p_expr = f"where match_probability >= {initial_threshold}"
+            match_p_select_expr = ", match_probability"
+
+        pipeline = CTEPipeline([df_predict])
+
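+        # Note (editor's comment): edges are filtered at the lowest threshold
+        # up front; since the thresholds are processed in ascending order, no
+        # later threshold ever needs an edge below initial_threshold.
+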
+        # Templated name must be used here because it could be the output
+        # of a deterministic link, i.e. the physical name is not known for sure
+        sql = f"""
+        select
+            {uid_concat_edges_l} as node_id_l,
+            {uid_concat_edges_r} as node_id_r
+            {match_p_select_expr}
+        from {df_predict.templated_name}
+        {match_p_expr}
+        """
+        pipeline.enqueue_sql(sql, "__splink__df_edges_from_predict")
+
+        edges_table_with_composite_ids = db_api.sql_pipeline_to_splink_dataframe(
+            pipeline
+        )
+        logger.info(f"--------Clustering at threshold {initial_threshold}--------")
+        # First cluster at the lowest threshold
+        cc = cluster_pairwise_predictions_at_threshold(
+            nodes=nodes_with_composite_ids,
+            edges=edges_table_with_composite_ids,
+            db_api=db_api,
+            node_id_column_name="node_id",
+            edge_id_column_name_left="node_id_l",
+            edge_id_column_name_right="node_id_r",
+            threshold_match_probability=initial_threshold,
+        )
+
+        all_results[initial_threshold] = cc
+
+        # Calculate summary stats for the first clustering threshold
+        pipeline = CTEPipeline([cc])
+        sqls = _get_cluster_stats_sql(cc)
+        pipeline.enqueue_list_of_sqls(sqls)
+        cc_summary = db_api.sql_pipeline_to_splink_dataframe(pipeline)
+        all_results_summary[initial_threshold] = cc_summary
+
+        # Now iterate over the remaining thresholds
+        previous_threshold = initial_threshold
+        for new_threshold in threshold_match_probabilities:
+            # Get stable clusters
+            logger.info(f"--------Clustering at threshold {new_threshold}--------")
+            pipeline = CTEPipeline([cc, edges_table_with_composite_ids])
+
+            sqls = _calculate_stable_clusters_at_new_threshold(
+                edges_sdf=edges_table_with_composite_ids,
+                cc=cc,
+                node_id_column_name="node_id",
+                edge_id_column_name_left="node_id_l",
+                edge_id_column_name_right="node_id_r",
+                previous_threshold_match_probability=previous_threshold,
+                new_threshold_match_probability=new_threshold,
+            )
+
+            pipeline.enqueue_list_of_sqls(sqls)
+            stable_clusters = db_api.sql_pipeline_to_splink_dataframe(pipeline)
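+
+            # (Editor's note) A cluster is "stable" when the lowest match
+            # probability among its internal edges is already >= new_threshold:
+            # raising the threshold cannot split it, so its cluster
+            # assignments carry over unchanged.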
+
+            # Derive nodes in play and edges in play by removing stable nodes. Then
+            # run cluster_pairwise_predictions_at_threshold at the new threshold
+
+            pipeline = CTEPipeline([nodes_with_composite_ids, stable_clusters])
+            sql = f"""
+            select *
+            from {nodes_with_composite_ids.templated_name}
+            where node_id not in
+            (select node_id from {stable_clusters.templated_name})
+            """
+            pipeline.enqueue_sql(sql, "__splink__nodes_in_play")
+            nodes_in_play = db_api.sql_pipeline_to_splink_dataframe(pipeline)
+
+            pipeline = CTEPipeline([nodes_in_play, edges_table_with_composite_ids])
+            sql = f"""
+            select *
+            from {edges_table_with_composite_ids.templated_name}
+            where node_id_l in
+            (select node_id from {nodes_in_play.templated_name})
+            and node_id_r in
+            (select node_id from {nodes_in_play.templated_name})
+            """
+            pipeline.enqueue_sql(sql, "__splink__edges_in_play")
+            edges_in_play = db_api.sql_pipeline_to_splink_dataframe(pipeline)
+
+            marginal_new_clusters = cluster_pairwise_predictions_at_threshold(
+                nodes_in_play,
+                edges_in_play,
+                node_id_column_name="node_id",
+                edge_id_column_name_left="node_id_l",
+                edge_id_column_name_right="node_id_r",
+                db_api=db_api,
+                threshold_match_probability=new_threshold,
+            )
+
+            pipeline = CTEPipeline([stable_clusters, marginal_new_clusters])
+            sql = f"""
+            SELECT node_id, cluster_id
+            FROM {stable_clusters.templated_name}
+            UNION ALL
+            SELECT node_id, cluster_id
+            FROM {marginal_new_clusters.templated_name}
+            """
+
+            pipeline.enqueue_sql(sql, "__splink__clusters_at_threshold")
+
+            cc = db_api.sql_pipeline_to_splink_dataframe(pipeline)
+
+            previous_threshold = new_threshold
+
+            edges_in_play.drop_table_from_database_and_remove_from_cache()
+            nodes_in_play.drop_table_from_database_and_remove_from_cache()
+            stable_clusters.drop_table_from_database_and_remove_from_cache()
+            marginal_new_clusters.drop_table_from_database_and_remove_from_cache()
+
+            all_results[new_threshold] = cc
+
+            # Calculate summary stats for metadata
+            pipeline = CTEPipeline([cc])
+            sqls = _get_cluster_stats_sql(cc)
+            pipeline.enqueue_list_of_sqls(sqls)
+            cc_summary = db_api.sql_pipeline_to_splink_dataframe(pipeline)
+            all_results_summary[new_threshold] = cc_summary
+
+        sql = _generate_detailed_cluster_comparison_sql(
+            all_results,
+            unique_id_col="node_id",
+            is_match_weight=is_match_weight,
+        )
+        pipeline = CTEPipeline()
+        pipeline.enqueue_sql(sql, "__splink__clusters_at_all_thresholds")
+        joined = db_api.sql_pipeline_to_splink_dataframe(pipeline)
+
+        pipeline = CTEPipeline()
+        concat = compute_df_concat(linker, pipeline)
+
+        columns = concat_table_column_names(self._linker)
+        # don't want to include salting column in output if present
+        columns_without_salt = filter(lambda x: x != "__splink_salt", columns)
+
+        select_columns_sql = ", ".join(columns_without_salt)
+
+        pipeline = CTEPipeline([joined, concat])
+        sql = f"""
+        select
+            co.*,
+            {select_columns_sql}
+        from {joined.physical_name} as co
+        left join {concat.physical_name} as c
+        on co.node_id = {uid_concat_nodes}
+        """
+        pipeline.enqueue_sql(
+            sql, "__splink__clusters_at_all_thresholds_with_input_data"
+        )
+
+        df_clustered_with_input_data = db_api.sql_pipeline_to_splink_dataframe(pipeline)
+
+        # Add metadata to the output dataframe
+        ## Match probability thresholds
+        df_clustered_with_input_data.metadata["threshold_match_probabilities"] = [
+            initial_threshold
+        ] + threshold_match_probabilities
+
+        ## Cluster summary stats
+        pipeline = CTEPipeline()
+        sql = _generate_cluster_summary_stats_sql(all_results_summary)
+        pipeline.enqueue_sql(sql, "__splink__cluster_summary_stats")
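+        # Execute the enqueued summary SQL and attach the resulting table of
+        # per-threshold statistics as metadata on the output dataframe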
+        df_clustered_with_input_data.metadata["cluster_summary_stats"] = (
+            db_api.sql_pipeline_to_splink_dataframe(pipeline)
+        )
+
+        # Drop cached tables
+        for v in all_results.values():
+            v.drop_table_from_database_and_remove_from_cache()
+        cc.drop_table_from_database_and_remove_from_cache()
+
+        return df_clustered_with_input_data
+
     def _compute_metrics_nodes(
         self,
         df_predict: SplinkDataFrame,
@@ -500,10 +813,13 @@ def compute_graph_metrics(
 
         Returns:
             GraphMetricsResult: A data class containing SplinkDataFrames
-            of cluster IDs and selected node, edge or cluster metrics.
-                attribute "nodes" for nodes metrics table
-                attribute "edges" for edge metrics table
-                attribute "clusters" for cluster metrics table
+            of cluster IDs and selected node, edge or cluster metrics.
+
+                - attribute "nodes" for nodes metrics table
+
+                - attribute "edges" for edge metrics table
+
+                - attribute "clusters" for cluster metrics table
 
         Examples:
             ```python
@@ -518,6 +834,9 @@ def compute_graph_metrics(
             node_metrics = graph_metrics.nodes.as_pandas_dataframe()
             edge_metrics = graph_metrics.edges.as_pandas_dataframe()
             cluster_metrics = graph_metrics.clusters.as_pandas_dataframe()
+
+            # Access the metadata, including the match probability threshold
+            node_metadata = graph_metrics.nodes.metadata
             ```
         """
         if threshold_match_probability is None:
diff --git a/tests/test_clustering.py b/tests/test_clustering.py
index a15f6dc1d3..a262d0d71f 100644
--- a/tests/test_clustering.py
+++ b/tests/test_clustering.py
@@ -4,6 +4,9 @@
 import splink.comparison_library as cl
 from splink import DuckDBAPI, Linker, SettingsCreator, block_on
+from splink.clustering import (
+    cluster_pairwise_predictions_at_multiple_thresholds,
+)
 
 from .basic_settings import get_settings_dict
 from .decorator import mark_with_dialects_excluding
@@ -126,3 +129,106 @@ def test_clustering_no_edges(test_helpers, dialect):
     # due to blocking rules, df_predict will be empty
     df_predict = linker.inference.predict()
     linker.clustering.cluster_pairwise_predictions_at_threshold(df_predict, 0.95)
+
+
+@mark_with_dialects_excluding()
+@mark.parametrize(
+    ["link_type", "input_pd_tables"],
+    [
+        ["dedupe_only", [df]],
+        ["link_only", [df, df]],  # no source dataset
+        ["link_only", [df_l, df_r]],  # source dataset column
+        ["link_only", [df_combined]],  # concatenated frame
+        ["link_only", [df_l, df_m, df_r]],
+        ["link_and_dedupe", [df, df]],  # no source dataset
+        ["link_and_dedupe", [df_l, df_r]],  # source dataset column
+        ["link_and_dedupe", [df_combined]],  # concatenated frame
+    ],
+    ids=[
+        "dedupe",
+        "link_only_no_source_dataset",
+        "link_only_with_source_dataset",
+        "link_only_concat",
+        "link_only_three_tables",
+        "link_and_dedupe_no_source_dataset",
+        "link_and_dedupe_with_source_dataset",
+        "link_and_dedupe_concat",
+    ],
+)
+def test_clustering_multi_thresholds(
+    test_helpers, dialect, link_type, input_pd_tables
+):
+    helper = test_helpers[dialect]
+
+    settings = SettingsCreator(
+        link_type=link_type,
+        comparisons=[
+            cl.ExactMatch("first_name"),
+            cl.ExactMatch("surname"),
+            cl.ExactMatch("dob"),
+            cl.ExactMatch("city"),
+        ],
+        blocking_rules_to_generate_predictions=[
+            block_on("surname"),
+            block_on("dob"),
+        ],
+    )
+    linker_input = list(map(helper.convert_frame, input_pd_tables))
+    linker = Linker(linker_input, settings, **helper.extra_linker_args())
+
+    df_predict = linker.inference.predict()
+    linker.clustering.cluster_pairwise_predictions_at_multiple_thresholds(
+        df_predict, [0.5, 0.95]
+    )
+
+
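+# Note on column naming (inferred from the output columns used below): the
+# multi-threshold output has one cluster column per threshold, named
+# cluster_p_<threshold> with "." replaced by "_", e.g. thresholds
+# [0.5, 0.9] give columns cluster_p_0_5 and cluster_p_0_9.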
pd.read_csv("./tests/datasets/fake_1000_from_splink_demos.csv") + db_api = DuckDBAPI() + settings_dict = get_settings_dict() + linker = Linker(df, settings_dict, db_api=db_api) + + df_predict = linker.inference.predict() + + clusters_0_5 = linker.clustering.cluster_pairwise_predictions_at_threshold(df_predict, 0.5).as_pandas_dataframe() + clusters_0_9 = linker.clustering.cluster_pairwise_predictions_at_threshold(df_predict, 0.9).as_pandas_dataframe() + + clusters_multi = linker.clustering.cluster_pairwise_predictions_at_multiple_thresholds(df_predict, [0.5, 0.9]).as_pandas_dataframe() + + df = pd.merge(clusters_0_5, clusters_multi, left_on='unique_id', right_on='unique_id', how='inner') + + df["different"] = df["cluster_id"] != df["cluster_p_0_9"] + compare = df[["cluster_id", "cluster_p_0_9", "different"]] + df.sort_values(by='different', ascending=False, inplace=True) + print(compare[compare["different"]==True]) + print(sum(compare["different"])) + + assert clusters_0_5["cluster_id"].equals(clusters_multi["cluster_p_0_5"]) + assert clusters_0_9["cluster_id"].equals(clusters_multi["cluster_p_0_9"]) + + + +def test_clustering_multi_threshold_linker_non_linker_equivalence(): + df = pd.read_csv("./tests/datasets/fake_1000_from_splink_demos.csv") + db_api = DuckDBAPI() + settings_dict = get_settings_dict() + linker = Linker(df, settings_dict, db_api=db_api) + + df_predict = linker.inference.predict() + + clusters_linker = linker.clustering.cluster_pairwise_predictions_at_multiple_thresholds( + df_predict, + [0.5, 0.95] + ).as_pandas_dataframe() + clusters_non_linker = cluster_pairwise_predictions_at_multiple_thresholds( + df, + df_predict, + node_id_column_name="unique_id", + edge_id_column_name_left="unique_id_l", + edge_id_column_name_right="unique_id_r", + db_api=linker._db_api, + match_probability_thresholds=[0.5, 0.95] + ).as_pandas_dataframe() + df = pd.DataFrame({'linker': clusters_linker['cluster_p_0_5'], 'non-linker': clusters_non_linker['cluster_p_0_5']}) + # df["different"] = df["linker"] != df["non-linker"] + # df.sort_values(by='different', ascending=False, inplace=True) + # print(df) + # print(sum(df["different"])) + + #assert clusters_linker["cluster_p_0_5"].equals(clusters_non_linker["cluster_p_0_5"]) + #assert clusters_linker["cluster_p_0_95"].equals(clusters_non_linker["cluster_p_0_95"]) \ No newline at end of file