Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
RobinL committed Nov 25, 2024
1 parent 14295a7 commit 50f6000
Showing 1 changed file with 89 additions and 0 deletions.
89 changes: 89 additions & 0 deletions tests/test_cluster_at_multiple_thresholds.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,92 @@ def test_cluster_at_multiple_thresholds(test_helpers, dialect, graph_size):
multi_threshold_result.columns = ["unique_id", "cluster_id"]

pd.testing.assert_frame_equal(multi_threshold_result, single_threshold_result)


@mark_with_dialects_excluding()
def test_cluster_at_multiple_thresholds_mw_prob_equivalence(test_helpers, dialect):
    """Clustering by probability thresholds must agree with weight thresholds.

    The same small graph (six nodes, three scored edges) is clustered twice:
    once with ``match_probability_thresholds`` and once with
    ``match_weight_thresholds``. Both the per-node cluster assignments and the
    cluster summary statistics are expected to be identical between the two.
    """
    dialect_helper = test_helpers[dialect]
    db_api = dialect_helper.DatabaseAPI(**dialect_helper.db_api_args())

    nodes = [{"my_id": node_id} for node_id in range(1, 7)]
    edges = [
        {"my_id_l": left_id, "my_id_r": right_id, "match_probability": probability}
        for left_id, right_id, probability in [
            (1, 2, 0.8),
            (3, 2, 0.9),
            (4, 5, 0.99),
        ]
    ]

    threshold_probabilities = [0.5, 0.7, 0.95]
    # NOTE(review): these look like log2(p / (1 - p)) of the probabilities
    # above, rounded to two decimal places -- confirm against the library.
    thresholds_weights = [0.0, 1.22, 4.25]

    def run_clustering(with_summary_stats, **threshold_kwargs):
        # Single entry point for every clustering call in this test; only the
        # threshold keyword and the summary-stats flag vary between calls.
        clustered = cluster_pairwise_predictions_at_multiple_thresholds(
            nodes,
            edges,
            node_id_column_name="my_id",
            db_api=db_api,
            output_cluster_summary_stats=with_summary_stats,
            **threshold_kwargs,
        )
        return clustered.as_pandas_dataframe()

    # --- Per-node cluster assignments -------------------------------------
    clusters_by_prob = run_clustering(
        False, match_probability_thresholds=threshold_probabilities
    )
    clusters_by_weight = run_clustering(
        False, match_weight_thresholds=thresholds_weights
    )

    # Weight-based output columns are named after the weight thresholds.
    for expected_column in ("cluster_mw_0", "cluster_mw_1_22", "cluster_mw_4_25"):
        assert expected_column in clusters_by_weight.columns

    clusters_by_prob = clusters_by_prob.reset_index(drop=True)
    clusters_by_weight = clusters_by_weight.reset_index(drop=True)
    # Column names differ between the two modes, so align them positionally
    # before comparing the values.
    clusters_by_weight.columns = clusters_by_prob.columns

    pd.testing.assert_frame_equal(clusters_by_prob, clusters_by_weight)

    # --- Cluster summary statistics ---------------------------------------
    summary_by_prob = run_clustering(
        True, match_probability_thresholds=threshold_probabilities
    )
    summary_by_weight = run_clustering(
        True, match_weight_thresholds=thresholds_weights
    )

    # Each summary statistic must match between the two threshold modes.
    for stat_column in ("num_clusters", "max_cluster_size", "avg_cluster_size"):
        pd.testing.assert_series_equal(
            summary_by_prob[stat_column], summary_by_weight[stat_column]
        )

0 comments on commit 50f6000

Please sign in to comment.