Merge pull request #2578 from aymonwuolanne/master
One to one clustering
RobinL authored Feb 11, 2025
2 parents 2e3493a + a800990 commit 44d6f5d
Showing 5 changed files with 701 additions and 7 deletions.
7 changes: 2 additions & 5 deletions .github/workflows/pytest_duckdb.yml
@@ -41,13 +41,10 @@ jobs:
- name: Upload coverage report
uses: actions/upload-artifact@v4
if: ${{ matrix.python-version == '3.11' }}
with:
name: coverage-report
name: coverage-report-${{ matrix.python-version }}
path: coverage.xml
- name: Upload to Codecov
uses: codecov/codecov-action@v3
# upload a single run version - should be representative coverage-wise as we don't have much version-dependent code
if: ${{ matrix.python-version == '3.11' }}
uses: codecov/codecov-action@v4
with:
token: ${{ secrets.CODECOV_TOKEN }}
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -9,8 +9,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Support for 'one to one' linking and clustering (allowing the user to force clusters to contain at most one record from given `source_dataset`s) in [#2578](https://github.com/moj-analytical-services/splink/pull/2578/)
- `ColumnExpression` now supports accessing first or last element of an array column via method `access_extreme_array_element()` ([#2585](https://github.com/moj-analytical-services/splink/pull/2585)), or converting string literals to `NULL` via `nullif()` ([#2586](https://github.com/moj-analytical-services/splink/pull/2586))
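
The 'one to one' clustering entry above corresponds to the new `cluster_using_single_best_links` method added in `splink/internals/linker_components/clustering.py` further down this diff. A minimal usage sketch follows; it assumes a splink `Linker` has already been configured and trained on input tables whose source dataset values are "A" and "B", and the threshold values are illustrative only, mirroring the docstring example in this PR.

```python
# Hedged sketch (assumed setup): `linker` is an existing, trained splink Linker
# whose records carry source dataset values "A" and "B".
df_predict = linker.inference.predict(threshold_match_probability=0.5)

# Cluster so that no cluster contains more than one record from "A" or from "B".
df_clustered = linker.clustering.cluster_using_single_best_links(
    df_predict,
    duplicate_free_datasets=["A", "B"],
    threshold_match_probability=0.95,
)
```

The values passed in `duplicate_free_datasets` must match values that appear in the linker's source dataset column, and per the new docstring they can be a subset of the input datasets.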


### Deprecated

- Deprecated support for python `3.8.x` following end of support for that minor version ([#2520](https://github.com/moj-analytical-services/splink/pull/2520))
@@ -51,7 +53,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Cluster without linker by @RobinL in https://github.com/moj-analytical-services/splink/pull/2412
- Better autocomplete for dataframes by @RobinL in https://github.com/moj-analytical-services/splink/pull/2434


## [4.0.2] - 2024-09-19

### Added
159 changes: 158 additions & 1 deletion splink/internals/linker_components/clustering.py
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, List, Optional

from splink.internals.connected_components import (
solve_connected_components,
@@ -14,6 +14,9 @@
from splink.internals.misc import (
threshold_args_to_match_prob,
)
from splink.internals.one_to_one_clustering import (
one_to_one_clustering,
)
from splink.internals.pipeline import CTEPipeline
from splink.internals.splink_dataframe import SplinkDataFrame
from splink.internals.unique_id_concat import (
@@ -177,6 +180,160 @@ def cluster_pairwise_predictions_at_threshold(

return df_clustered_with_input_data

def cluster_using_single_best_links(
self,
df_predict: SplinkDataFrame,
duplicate_free_datasets: List[str],
threshold_match_probability: Optional[float] = None,
threshold_match_weight: Optional[float] = None,
) -> SplinkDataFrame:
"""
Clusters the pairwise match predictions that result from
`linker.inference.predict()` into groups of connected records using a single
best links method that restricts the clusters to have at most one record from
each source dataset in the `duplicate_free_datasets` list.

This method will include a record in a cluster only if it is mutually the best
match for the record and for the cluster, and if adding the record does not
violate the criterion of having at most one record from each of the
`duplicate_free_datasets`.

Args:
df_predict (SplinkDataFrame): The results of `linker.inference.predict()`
duplicate_free_datasets (List[str]): The source datasets which should be
treated as having no duplicates. Clusters will not form with more than
one record from each of these datasets. This can be a subset of all of
the source datasets in the input data.
threshold_match_probability (float, optional): Pairwise comparisons with a
`match_probability` at or above this threshold are matched.
threshold_match_weight (float, optional): Pairwise comparisons with a
`match_weight` at or above this threshold are matched. Only one of
threshold_match_probability or threshold_match_weight should be provided.

Returns:
SplinkDataFrame: A SplinkDataFrame containing a list of all IDs, clustered
into groups based on the desired match threshold and the source datasets
for which duplicates are not allowed.

Examples:
```python
df_predict = linker.inference.predict(threshold_match_probability=0.5)
df_clustered = linker.clustering.cluster_using_single_best_links(
    df_predict,
    duplicate_free_datasets=["A", "B"],
    threshold_match_probability=0.95
)
```
"""
linker = self._linker
db_api = linker._db_api

pipeline = CTEPipeline()

enqueue_df_concat(linker, pipeline)

uid_cols = linker._settings_obj.column_info_settings.unique_id_input_columns
uid_concat_edges_l = _composite_unique_id_from_edges_sql(uid_cols, "l")
uid_concat_edges_r = _composite_unique_id_from_edges_sql(uid_cols, "r")
uid_concat_nodes = _composite_unique_id_from_nodes_sql(uid_cols, None)
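# These helpers build SQL expressions that concatenate the unique id input
# columns (for a linking job this is typically source_dataset plus unique_id)
# into a single composite id, so node ids are unambiguous across input tables.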

source_dataset_column_name = (
linker._settings_obj.column_info_settings.source_dataset_column_name
)

sql = f"""
select
{uid_concat_nodes} as node_id,
{source_dataset_column_name} as source_dataset
from __splink__df_concat
"""
pipeline.enqueue_sql(sql, "__splink__df_nodes_with_composite_ids")

nodes_with_composite_ids = db_api.sql_pipeline_to_splink_dataframe(pipeline)

has_match_prob_col = "match_probability" in [
c.unquote().name for c in df_predict.columns
]

threshold_match_probability = threshold_args_to_match_prob(
threshold_match_probability, threshold_match_weight
)
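# threshold_args_to_match_prob normalises the two optional thresholds into a
# single match probability. Since splink defines match weight as log2 of the
# match odds, a weight threshold w corresponds to a probability of
# 2**w / (1 + 2**w), e.g. w = 4 is roughly 0.94.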

if not has_match_prob_col and threshold_match_probability is not None:
raise ValueError(
"df_predict must have a column called 'match_probability' if "
"threshold_match_probability is provided"
)

match_p_expr = ""
match_p_select_expr = ""
if threshold_match_probability is not None:
match_p_expr = f"where match_probability >= {threshold_match_probability}"
match_p_select_expr = ", match_probability"

pipeline = CTEPipeline([df_predict])

# Templated name must be used here because it could be the output
# of a deterministic link, i.e. the templated name is not known for sure
sql = f"""
select
{uid_concat_edges_l} as node_id_l,
{uid_concat_edges_r} as node_id_r
{match_p_select_expr}
from {df_predict.templated_name}
{match_p_expr}
"""
pipeline.enqueue_sql(sql, "__splink__df_edges_from_predict")

edges_table_with_composite_ids = db_api.sql_pipeline_to_splink_dataframe(
pipeline
)

oo = one_to_one_clustering(
nodes_table=nodes_with_composite_ids,
edges_table=edges_table_with_composite_ids,
node_id_column_name="node_id",
source_dataset_column_name="source_dataset",
edge_id_column_name_left="node_id_l",
edge_id_column_name_right="node_id_r",
duplicate_free_datasets=duplicate_free_datasets,
db_api=db_api,
threshold_match_probability=threshold_match_probability,
)

edges_table_with_composite_ids.drop_table_from_database_and_remove_from_cache()
nodes_with_composite_ids.drop_table_from_database_and_remove_from_cache()
pipeline = CTEPipeline([oo])

enqueue_df_concat(linker, pipeline)

columns = concat_table_column_names(self._linker)
# don't want to include salting column in output if present
columns_without_salt = filter(lambda x: x != "__splink_salt", columns)

select_columns_sql = ", ".join(columns_without_salt)

sql = f"""
select
oo.cluster_id,
{select_columns_sql}
from {oo.templated_name} as oo
left join __splink__df_concat
on oo.node_id = {uid_concat_nodes}
"""
pipeline.enqueue_sql(sql, "__splink__df_clustered_with_input_data")

df_clustered_with_input_data = db_api.sql_pipeline_to_splink_dataframe(pipeline)

oo.drop_table_from_database_and_remove_from_cache()

if threshold_match_probability is not None:
df_clustered_with_input_data.metadata["threshold_match_probability"] = (
threshold_match_probability
)

return df_clustered_with_input_data

def _compute_metrics_nodes(
self,
df_predict: SplinkDataFrame,
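
To make the clustering rule described in the new docstring concrete, here is a small, self-contained illustration of the single-best-links constraint. It is a hedged toy sketch only, not the SQL implementation in `splink.internals.one_to_one_clustering`: it greedily accepts the strongest remaining link first (so a link is only considered once every stronger link touching either record has been settled) and refuses any merge that would put two records from the same duplicate-free dataset into one cluster.

```python
from collections import Counter


def single_best_links_toy(nodes, edges, duplicate_free_datasets):
    """nodes: dict mapping node_id -> source_dataset.
    edges: list of (node_id_l, node_id_r, match_probability) tuples."""
    cluster_of = {node: node for node in nodes}   # each node starts in its own cluster
    members = {node: [node] for node in nodes}

    def dataset_counts(cluster):
        # how many records each duplicate-free dataset contributes to a cluster
        return Counter(
            nodes[m] for m in members[cluster]
            if nodes[m] in duplicate_free_datasets
        )

    # strongest links first: a link is only considered once every stronger
    # link involving either record has already been settled
    for left, right, _prob in sorted(edges, key=lambda e: e[2], reverse=True):
        cluster_l, cluster_r = cluster_of[left], cluster_of[right]
        if cluster_l == cluster_r:
            continue
        combined = dataset_counts(cluster_l) + dataset_counts(cluster_r)
        if any(count > 1 for count in combined.values()):
            continue  # merging would duplicate a duplicate-free dataset
        for m in members[cluster_r]:
            cluster_of[m] = cluster_l
        members[cluster_l].extend(members.pop(cluster_r))

    return cluster_of


clusters = single_best_links_toy(
    nodes={"A-1": "A", "A-2": "A", "B-1": "B"},
    edges=[("A-1", "B-1", 0.97), ("A-2", "B-1", 0.91)],
    duplicate_free_datasets=["A", "B"],
)
# {"A-1": "A-1", "A-2": "A-2", "B-1": "A-1"}: B-1 joins its best match A-1,
# and A-2 cannot join because the cluster would then contain two "A" records.
```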