Skip to content

Commit

Permalink
formatting issues
Browse files: browse the repository at this point in the history
  • Loading branch information
RossKen committed Feb 11, 2025
1 parent a253d4b commit 3f6d09d
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 14 deletions.
10 changes: 8 additions & 2 deletions splink/clustering.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
from .internals.clustering import cluster_pairwise_predictions_at_threshold, cluster_pairwise_predictions_at_multiple_thresholds
from .internals.clustering import (
cluster_pairwise_predictions_at_multiple_thresholds,
cluster_pairwise_predictions_at_threshold
)

__all__ = ["cluster_pairwise_predictions_at_threshold", "cluster_pairwise_predictions_at_multiple_thresholds"]
__all__ = [
"cluster_pairwise_predictions_at_threshold",
"cluster_pairwise_predictions_at_multiple_thresholds"
]
29 changes: 17 additions & 12 deletions splink/internals/linker_components/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,24 +358,25 @@ def cluster_pairwise_predictions_at_multiple_thresholds(
Records with an estimated `match_probability` at or above each of the values in
`threshold_match_probabilities` (or records with a `match_weight` at or above
each of the values in `threshold_match_weights`) are considered to be a match
each of the values in `threshold_match_weights`) are considered to be a match
(i.e. they represent the same entity).
This function efficiently computes clusters for multiple thresholds by starting
with the lowest threshold and incrementally updating clusters for higher thresholds.
with the lowest threshold and incrementally updating clusters for higher
thresholds.
Args:
df_predict (SplinkDataFrame): The results of `linker.predict()`
threshold_match_probabilities (list[float] | None): List of match probability
thresholds to compute clusters for
threshold_match_weights (list[float] | None): List of match weight thresholds
to compute clusters for
threshold_match_probabilities (list[float] | None): List of match
probability thresholds to compute clusters for
threshold_match_weights (list[float] | None): List of match weight
thresholds to compute clusters for
output_cluster_summary_stats (bool): If True, output summary statistics
for each threshold instead of full cluster information
Returns:
SplinkDataFrame: A SplinkDataFrame containing a list of all IDs, clustered
into groups for each of the desired match thresholds.
into groups for each of the desired match thresholds.
If output_cluster_summary_stats is True, it contains summary
statistics (number of clusters, max cluster size, avg cluster size) for
each threshold.
Expand Down Expand Up @@ -428,19 +429,21 @@ def cluster_pairwise_predictions_at_multiple_thresholds(
]

is_match_weight = (
threshold_match_weights is not None and threshold_match_probabilities is None
threshold_match_weights is not None
and threshold_match_probabilities is None
)

threshold_match_probabilities = threshold_args_to_match_prob_list(
threshold_match_probabilities, threshold_match_weights
)

if threshold_match_probabilities is None or len(threshold_match_probabilities) == 0:
if (threshold_match_probabilities is None
or len(threshold_match_probabilities) == 0):
raise ValueError(
"Must provide either threshold_match_probabilities "
"or threshold_match_weights"
)

if not has_match_prob_col and threshold_match_probabilities is not None:
raise ValueError(
"df_predict must have a column called 'match_probability' if "
Expand Down Expand Up @@ -607,7 +610,9 @@ def cluster_pairwise_predictions_at_multiple_thresholds(
left join __splink__df_concat
on co.node_id = {uid_concat_nodes}
"""
pipeline.enqueue_sql(sql, "__splink__clusters_at_all_thresholds_with_input_data")
pipeline.enqueue_sql(sql,
"__splink__clusters_at_all_thresholds_with_input_data"
)

df_clustered_with_input_data = db_api.sql_pipeline_to_splink_dataframe(pipeline)

Expand All @@ -618,7 +623,7 @@ def cluster_pairwise_predictions_at_multiple_thresholds(
df_clustered_with_input_data.metadata["threshold_match_probabilities"] = (
[initial_threshold] + threshold_match_probabilities
)

return df_clustered_with_input_data

def _compute_metrics_nodes(
Expand Down

0 comments on commit 3f6d09d

Please sign in to comment.