diff --git a/tests/test_clustering.py b/tests/test_clustering.py index a15f6dc1d3..0afb7a4266 100644 --- a/tests/test_clustering.py +++ b/tests/test_clustering.py @@ -4,6 +4,9 @@ import splink.comparison_library as cl from splink import DuckDBAPI, Linker, SettingsCreator, block_on +from splink.clustering import ( + cluster_pairwise_predictions_at_multiple_thresholds +) from .basic_settings import get_settings_dict from .decorator import mark_with_dialects_excluding @@ -126,3 +129,93 @@ def test_clustering_no_edges(test_helpers, dialect): # due to blocking rules, df_predict will be empty df_predict = linker.inference.predict() linker.clustering.cluster_pairwise_predictions_at_threshold(df_predict, 0.95) + +@mark_with_dialects_excluding() +@mark.parametrize( + ["link_type", "input_pd_tables"], + [ + ["dedupe_only", [df]], + ["link_only", [df, df]], # no source dataset + ["link_only", [df_l, df_r]], # source dataset column + ["link_only", [df_combined]], # concatenated frame + ["link_only", [df_l, df_m, df_r]], + ["link_and_dedupe", [df, df]], # no source dataset + ["link_and_dedupe", [df_l, df_r]], # source dataset column + ["link_and_dedupe", [df_combined]], # concatenated frame + ], + ids=[ + "dedupe", + "link_only_no_source_dataset", + "link_only_with_source_dataset", + "link_only_concat", + "link_only_three_tables", + "link_and_dedupe_no_source_dataset", + "link_and_dedupe_with_source_dataset", + "link_and_dedupe_concat", + ], +) +def test_clustering_multi_thresholds(test_helpers, dialect, link_type, input_pd_tables): + helper = test_helpers[dialect] + + settings = SettingsCreator( + link_type=link_type, + comparisons=[ + cl.ExactMatch("first_name"), + cl.ExactMatch("surname"), + cl.ExactMatch("dob"), + cl.ExactMatch("city"), + ], + blocking_rules_to_generate_predictions=[ + block_on("surname"), + block_on("dob"), + ], + ) + linker_input = list(map(helper.convert_frame, input_pd_tables)) + linker = Linker(linker_input, settings, **helper.extra_linker_args()) + + df_predict = linker.inference.predict() + linker.clustering.cluster_pairwise_predictions_at_multiple_thresholds(df_predict, [0.5, 0.95]) + + +def test_clustering_single_multi_threshold_equivalence(): + df = pd.read_csv("./tests/datasets/fake_1000_from_splink_demos.csv") + db_api = DuckDBAPI() + settings_dict = get_settings_dict() + linker = Linker(df, settings_dict, db_api=db_api) + + df_predict = linker.inference.predict() + + clusters_0_5 = linker.clustering.cluster_pairwise_predictions_at_threshold(df_predict, 0.5).as_pandas_dataframe() + clusters_0_95 = linker.clustering.cluster_pairwise_predictions_at_threshold(df_predict, 0.95).as_pandas_dataframe() + + clusters_multi = linker.clustering.cluster_pairwise_predictions_at_multiple_thresholds(df_predict, [0.5, 0.95]).as_pandas_dataframe() + + assert clusters_0_5["cluster_id"] == clusters_multi["cluster_p_0_5"] + assert clusters_0_95["cluster_id"] == clusters_multi["cluster_p_0_95"] + + + +def test_clustering_multi_threshold_linker_non_linker_equivalence(): + df = pd.read_csv("./tests/datasets/fake_1000_from_splink_demos.csv") + db_api = DuckDBAPI() + settings_dict = get_settings_dict() + linker = Linker(df, settings_dict, db_api=db_api) + + df_predict = linker.inference.predict() + + clusters_linker = linker.clustering.cluster_pairwise_predictions_at_multiple_thresholds( + df_predict, + [0.5, 0.95] + ).as_pandas_dataframe() + clusters_non_linker = cluster_pairwise_predictions_at_multiple_thresholds( + df, + df_predict, + node_id_column_name="unique_id", + edge_id_column_name_left="unique_id_l", + edge_id_column_name_right="unique_id_r", + db_api=linker.db_api, + match_probability_thresholds=[0.5, 0.95] + ).as_pandas_dataframe() + + assert clusters_linker["cluster_p_0_5"] == clusters_non_linker["cluster_p_0_5"] + assert clusters_linker["cluster_p_0_95"] == clusters_non_linker["cluster_p_0_95"] \ No newline at end of file