Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
RossKen committed Feb 13, 2025
1 parent 5b15a6f commit b7a155f
Showing 1 changed file with 93 additions and 0 deletions.
93 changes: 93 additions & 0 deletions tests/test_clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@

import splink.comparison_library as cl
from splink import DuckDBAPI, Linker, SettingsCreator, block_on
from splink.clustering import (
cluster_pairwise_predictions_at_multiple_thresholds
)

from .basic_settings import get_settings_dict
from .decorator import mark_with_dialects_excluding
Expand Down Expand Up @@ -126,3 +129,93 @@ def test_clustering_no_edges(test_helpers, dialect):
# due to blocking rules, df_predict will be empty
df_predict = linker.inference.predict()
linker.clustering.cluster_pairwise_predictions_at_threshold(df_predict, 0.95)

@mark_with_dialects_excluding()
@mark.parametrize(
    "link_type, input_pd_tables",
    [
        ("dedupe_only", [df]),
        ("link_only", [df, df]),  # no source dataset
        ("link_only", [df_l, df_r]),  # source dataset column
        ("link_only", [df_combined]),  # concatenated frame
        ("link_only", [df_l, df_m, df_r]),
        ("link_and_dedupe", [df, df]),  # no source dataset
        ("link_and_dedupe", [df_l, df_r]),  # source dataset column
        ("link_and_dedupe", [df_combined]),  # concatenated frame
    ],
    ids=[
        "dedupe",
        "link_only_no_source_dataset",
        "link_only_with_source_dataset",
        "link_only_concat",
        "link_only_three_tables",
        "link_and_dedupe_no_source_dataset",
        "link_and_dedupe_with_source_dataset",
        "link_and_dedupe_concat",
    ],
)
def test_clustering_multi_thresholds(test_helpers, dialect, link_type, input_pd_tables):
    """Smoke test: multi-threshold clustering runs for every link type and
    every way of supplying input tables (single, multiple, pre-concatenated)."""
    helper = test_helpers[dialect]

    settings_creator = SettingsCreator(
        link_type=link_type,
        comparisons=[
            cl.ExactMatch(col)
            for col in ("first_name", "surname", "dob", "city")
        ],
        blocking_rules_to_generate_predictions=[
            block_on("surname"),
            block_on("dob"),
        ],
    )

    frames = [helper.convert_frame(tbl) for tbl in input_pd_tables]
    linker = Linker(frames, settings_creator, **helper.extra_linker_args())

    predictions = linker.inference.predict()
    # Should complete without raising for both thresholds in one pass
    linker.clustering.cluster_pairwise_predictions_at_multiple_thresholds(
        predictions, [0.5, 0.95]
    )


def test_clustering_single_multi_threshold_equivalence():
    """Clusters from two separate single-threshold runs must match the
    corresponding columns of one multi-threshold run over the same predictions.

    Bug fix: the original did ``assert series_a == series_b`` — comparing two
    pandas Series with ``==`` returns a Series, and ``assert`` on a Series
    raises ``ValueError: The truth value of a Series is ambiguous``. Compare
    materialized value lists instead, sorted by node id so row order cannot
    cause spurious failures.
    """
    df = pd.read_csv("./tests/datasets/fake_1000_from_splink_demos.csv")
    db_api = DuckDBAPI()
    settings_dict = get_settings_dict()
    linker = Linker(df, settings_dict, db_api=db_api)

    df_predict = linker.inference.predict()

    clusters_0_5 = linker.clustering.cluster_pairwise_predictions_at_threshold(
        df_predict, 0.5
    ).as_pandas_dataframe()
    clusters_0_95 = linker.clustering.cluster_pairwise_predictions_at_threshold(
        df_predict, 0.95
    ).as_pandas_dataframe()

    clusters_multi = linker.clustering.cluster_pairwise_predictions_at_multiple_thresholds(
        df_predict, [0.5, 0.95]
    ).as_pandas_dataframe()

    # NOTE(review): assumes the clustered output frames carry a "unique_id"
    # node column — confirm against the clustering output schema.
    single_0_5 = clusters_0_5.sort_values("unique_id")["cluster_id"].to_list()
    single_0_95 = clusters_0_95.sort_values("unique_id")["cluster_id"].to_list()
    multi_sorted = clusters_multi.sort_values("unique_id")

    assert single_0_5 == multi_sorted["cluster_p_0_5"].to_list()
    assert single_0_95 == multi_sorted["cluster_p_0_95"].to_list()



def test_clustering_multi_threshold_linker_non_linker_equivalence():
    """The linker method and the standalone ``splink.clustering`` function must
    produce identical clusters for the same predictions and thresholds.

    Bug fix: the original did ``assert series_a == series_b`` — comparing two
    pandas Series with ``==`` returns a Series, and ``assert`` on a Series
    raises ``ValueError: The truth value of a Series is ambiguous``. Compare
    materialized value lists instead, sorted by node id so row order cannot
    cause spurious failures.
    """
    df = pd.read_csv("./tests/datasets/fake_1000_from_splink_demos.csv")
    db_api = DuckDBAPI()
    settings_dict = get_settings_dict()
    linker = Linker(df, settings_dict, db_api=db_api)

    df_predict = linker.inference.predict()

    clusters_linker = linker.clustering.cluster_pairwise_predictions_at_multiple_thresholds(
        df_predict,
        [0.5, 0.95],
    ).as_pandas_dataframe()
    clusters_non_linker = cluster_pairwise_predictions_at_multiple_thresholds(
        df,
        df_predict,
        node_id_column_name="unique_id",
        edge_id_column_name_left="unique_id_l",
        edge_id_column_name_right="unique_id_r",
        db_api=linker.db_api,
        match_probability_thresholds=[0.5, 0.95],
    ).as_pandas_dataframe()

    # NOTE(review): assumes both output frames carry a "unique_id" node
    # column — confirm against the clustering output schema.
    linker_sorted = clusters_linker.sort_values("unique_id")
    non_linker_sorted = clusters_non_linker.sort_values("unique_id")

    assert (
        linker_sorted["cluster_p_0_5"].to_list()
        == non_linker_sorted["cluster_p_0_5"].to_list()
    )
    assert (
        linker_sorted["cluster_p_0_95"].to_list()
        == non_linker_sorted["cluster_p_0_95"].to_list()
    )

0 comments on commit b7a155f

Please sign in to comment.