Improve compare two records #2498

Merged · 22 commits · Nov 13, 2024
7 changes: 4 additions & 3 deletions splink/internals/accuracy.py
@@ -20,6 +20,7 @@

if TYPE_CHECKING:
from splink.internals.linker import Linker
from splink.internals.settings import Settings


def truth_space_table_from_labels_with_predictions_sqls(
@@ -289,8 +290,8 @@ def truth_space_table_from_labels_with_predictions_sqls(
return sqls


def _select_found_by_blocking_rules(linker: "Linker") -> str:
brs = linker._settings_obj._blocking_rules_to_generate_predictions
def _select_found_by_blocking_rules(settings_obj: "Settings") -> str:
brs = settings_obj._blocking_rules_to_generate_predictions

if brs:
br_strings = [
@@ -425,7 +426,7 @@ def predictions_from_sample_of_pairwise_labels_sql(linker, labels_tablename):
)

sqls.extend(sqls_2)
br_col = _select_found_by_blocking_rules(linker)
br_col = _select_found_by_blocking_rules(linker._settings_obj)

sql = f"""
select *, {br_col}
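The refactor above means `_select_found_by_blocking_rules` needs only a Settings object rather than a full Linker, so it can be reused by the new realtime module below. A minimal sketch of how it is called, assuming `settings_obj` is an instantiated Settings object; the SQL in the comment is illustrative only, not the exact expression Splink emits:

```py
# Minimal sketch, assuming `settings_obj` is a Splink Settings instance with
# blocking_rules_to_generate_predictions configured.
br_col = _select_found_by_blocking_rules(settings_obj)

# br_col is a select expression of roughly this shape (illustrative only):
#   coalesce((l.first_name = r.first_name), false)
#     or coalesce((l.dob = r.dob), false) as found_by_blocking_rules
sql = f"select *, {br_col} from __splink__df_predict"
```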
110 changes: 73 additions & 37 deletions splink/internals/linker_components/inference.py
@@ -4,6 +4,7 @@
import time
from typing import TYPE_CHECKING, Any

from splink.internals.accuracy import _select_found_by_blocking_rules
from splink.internals.blocking import (
BlockingRule,
block_using_rules_sqls,
@@ -639,16 +640,31 @@ def find_matches_to_new_records(
return predictions

def compare_two_records(
self, record_1: dict[str, Any], record_2: dict[str, Any]
self,
record_1: dict[str, Any] | AcceptableInputTableType,
record_2: dict[str, Any] | AcceptableInputTableType,
include_found_by_blocking_rules: bool = False,
) -> SplinkDataFrame:
"""Use the linkage model to compare and score a pairwise record comparison
based on the two input records provided
based on the two input records provided.

If your inputs contain multiple rows, scores for the Cartesian product of
the two inputs will be returned.

If your inputs contain hardcoded term frequency columns (e.g.
a tf_first_name column), then these values will be used instead of any
provided term frequency lookup tables or term frequency values derived
from the input data.

Args:
record_1 (dict): dictionary representing the first record. Column names
and data types must be the same as the columns in the settings object
record_2 (dict): dictionary representing the second record. Column names
and data types must be the same as the columns in the settings object
include_found_by_blocking_rules (bool, optional): If True, outputs a column
indicating whether the record pair would have been found by any of the
blocking rules specified in
settings.blocking_rules_to_generate_predictions. Defaults to False.

Examples:
```py
@@ -683,30 +699,39 @@ def compare_two_records(
SplinkDataFrame: Pairwise comparison with scored prediction
"""

cache = self._linker._intermediate_table_cache
linker = self._linker

retain_matching_columns = linker._settings_obj._retain_matching_columns
retain_intermediate_calculation_columns = (
linker._settings_obj._retain_intermediate_calculation_columns
)
linker._settings_obj._retain_matching_columns = True
linker._settings_obj._retain_intermediate_calculation_columns = True

cache = linker._intermediate_table_cache

uid = ascii_uid(8)

# Wrap a bare dict in a list so it can be registered as a single-row table
if isinstance(record_1, dict):
to_register_left = [record_1]
to_register_left: AcceptableInputTableType = [record_1]
else:
to_register_left = record_1

if isinstance(record_2, dict):
to_register_right = [record_2]
to_register_right: AcceptableInputTableType = [record_2]
else:
to_register_right = record_2

df_records_left = self._linker.table_management.register_table(
df_records_left = linker.table_management.register_table(
to_register_left,
f"__splink__compare_two_records_left_{uid}",
overwrite=True,
)

df_records_left.templated_name = "__splink__compare_two_records_left"

df_records_right = self._linker.table_management.register_table(
df_records_right = linker.table_management.register_table(
to_register_right,
f"__splink__compare_two_records_right_{uid}",
overwrite=True,
@@ -719,7 +744,9 @@
nodes_with_tf = cache.get_with_logging("__splink__df_concat_with_tf")
pipeline.append_input_dataframe(nodes_with_tf)

for tf_col in self._linker._settings_obj._term_frequency_columns:
tf_cols = linker._settings_obj._term_frequency_columns

for tf_col in tf_cols:
tf_table_name = colname_to_tf_tablename(tf_col)
if tf_table_name in cache:
tf_table = cache.get_with_logging(tf_table_name)
@@ -734,67 +761,76 @@
)

sql_join_tf = _join_new_table_to_df_concat_with_tf_sql(
self._linker, "__splink__compare_two_records_left"
linker, "__splink__compare_two_records_left", df_records_left
)

pipeline.enqueue_sql(sql_join_tf, "__splink__compare_two_records_left_with_tf")

sql_join_tf = _join_new_table_to_df_concat_with_tf_sql(
self._linker, "__splink__compare_two_records_right"
linker, "__splink__compare_two_records_right", df_records_right
)

pipeline.enqueue_sql(sql_join_tf, "__splink__compare_two_records_right_with_tf")

source_dataset_ic = (
self._linker._settings_obj.column_info_settings.source_dataset_input_column
)
uid_ic = self._linker._settings_obj.column_info_settings.unique_id_input_column

pipeline = add_unique_id_and_source_dataset_cols_if_needed(
self._linker,
linker,
df_records_left,
pipeline,
in_tablename="__splink__compare_two_records_left_with_tf",
out_tablename="__splink__compare_two_records_left_with_tf_uid_fix",
uid_str="_left",
)
pipeline = add_unique_id_and_source_dataset_cols_if_needed(
self._linker,
linker,
df_records_right,
pipeline,
in_tablename="__splink__compare_two_records_right_with_tf",
out_tablename="__splink__compare_two_records_right_with_tf_uid_fix",
uid_str="_right",
)

sqls = block_using_rules_sqls(
Member Author:

If we're using Cartesian blocking, we don't need to run any complex blocking code.

In addition, this code creates and materialises a list of pairwise ids, which it uses for the join. This is unnecessary in the context of a handful of records.

input_tablename_l="__splink__compare_two_records_left_with_tf_uid_fix",
input_tablename_r="__splink__compare_two_records_right_with_tf_uid_fix",
blocking_rules=[BlockingRule("1=1")],
link_type=self._linker._settings_obj._link_type,
source_dataset_input_column=source_dataset_ic,
unique_id_input_column=uid_ic,
)
pipeline.enqueue_list_of_sqls(sqls)
cols_to_select = self._linker._settings_obj._columns_to_select_for_blocking

sqls = compute_comparison_vector_values_from_id_pairs_sqls(
self._linker._settings_obj._columns_to_select_for_blocking,
self._linker._settings_obj._columns_to_select_for_comparison_vector_values,
input_tablename_l="__splink__compare_two_records_left_with_tf_uid_fix",
input_tablename_r="__splink__compare_two_records_right_with_tf_uid_fix",
source_dataset_input_column=source_dataset_ic,
unique_id_input_column=uid_ic,
select_expr = ", ".join(cols_to_select)
sql = f"""
select {select_expr}, 0 as match_key
from __splink__compare_two_records_left_with_tf_uid_fix as l
cross join __splink__compare_two_records_right_with_tf_uid_fix as r
"""
pipeline.enqueue_sql(sql, "__splink__compare_two_records_blocked")

cols_to_select = (
linker._settings_obj._columns_to_select_for_comparison_vector_values
)
pipeline.enqueue_list_of_sqls(sqls)
select_expr = ", ".join(cols_to_select)
sql = f"""
select {select_expr}
from __splink__compare_two_records_blocked
"""
pipeline.enqueue_sql(sql, "__splink__df_comparison_vectors")

sqls = predict_from_comparison_vectors_sqls_using_settings(
self._linker._settings_obj,
sql_infinity_expression=self._linker._infinity_expression,
linker._settings_obj,
sql_infinity_expression=linker._infinity_expression,
)
pipeline.enqueue_list_of_sqls(sqls)

predictions = self._linker._db_api.sql_pipeline_to_splink_dataframe(
if include_found_by_blocking_rules:
br_col = _select_found_by_blocking_rules(linker._settings_obj)
sql = f"""
select *, {br_col}
from __splink__df_predict
"""

pipeline.enqueue_sql(sql, "__splink__found_by_blocking_rules")

predictions = linker._db_api.sql_pipeline_to_splink_dataframe(
pipeline, use_cache=False
)

linker._settings_obj._retain_matching_columns = retain_matching_columns
linker._settings_obj._retain_intermediate_calculation_columns = (
retain_intermediate_calculation_columns
)

return predictions
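
To illustrate the new surface area, a hypothetical usage sketch; it assumes a trained `linker` whose settings define the columns used below:

```py
# Hypothetical usage sketch: assumes a trained Linker (`linker`) built
# elsewhere, with first_name, surname and dob in its settings.
record_1 = {"unique_id": 1, "first_name": "John", "surname": "Smith", "dob": "1980-01-01"}
record_2 = {"unique_id": 2, "first_name": "Jon", "surname": "Smith", "dob": "1980-01-01"}

df = linker.inference.compare_two_records(
    record_1,
    record_2,
    include_found_by_blocking_rules=True,  # flag added in this PR
)
df.as_pandas_dataframe()
```

Multi-row inputs (e.g. a table per side) are also accepted, in which case scores for the Cartesian product of the two inputs are returned.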
142 changes: 142 additions & 0 deletions splink/internals/realtime.py
@@ -0,0 +1,142 @@
from __future__ import annotations

from pathlib import Path
from typing import Any

from splink.internals.accuracy import _select_found_by_blocking_rules
from splink.internals.database_api import AcceptableInputTableType, DatabaseAPISubClass
from splink.internals.misc import ascii_uid
from splink.internals.pipeline import CTEPipeline
from splink.internals.predict import (
predict_from_comparison_vectors_sqls_using_settings,
)
from splink.internals.settings_creator import SettingsCreator
from splink.internals.splink_dataframe import SplinkDataFrame


class SQLCache:
def __init__(self):
self._cache: dict[int, tuple[str, str | None]] = {}

def get(self, settings_id: int, new_uid: str) -> str | None:
if settings_id not in self._cache:
return None

sql, cached_uid = self._cache[settings_id]
if cached_uid:
sql = sql.replace(cached_uid, new_uid)
return sql

def set(self, settings_id: int, sql: str | None, uid: str | None) -> None:
if sql is not None:
self._cache[settings_id] = (sql, uid)


_sql_cache = SQLCache()
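
# Illustrative note: the cache keys on id(settings) and stores the SQL along
# with the uid baked into its table names, so a later call can swap in a
# fresh uid instead of regenerating the SQL. Roughly:
#
#   cache = SQLCache()
#   cache.set(settings_id=123, sql="select * from t_abc123", uid="abc123")
#   cache.get(123, "xyz789")   # -> "select * from t_xyz789"
#   cache.get(999, "xyz789")   # -> None (unknown settings id)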


def compare_records(
record_1: dict[str, Any] | AcceptableInputTableType,
record_2: dict[str, Any] | AcceptableInputTableType,
settings: SettingsCreator | dict[str, Any] | Path | str,
db_api: DatabaseAPISubClass,
use_sql_from_cache: bool = True,
include_found_by_blocking_rules: bool = False,
) -> SplinkDataFrame:
"""Compare two records and compute similarity scores without requiring a Linker.
Assumes any required term frequency values are provided in the input records.

Args:
record_1 (dict): First record to compare
record_2 (dict): Second record to compare
settings (SettingsCreator | dict | Path | str): Model settings, or a path
to a saved model json
db_api (DatabaseAPISubClass): Database API to use for computations
use_sql_from_cache (bool, optional): If True, reuse SQL already generated
for the same settings object, substituting in a fresh uid. Defaults to True.
include_found_by_blocking_rules (bool, optional): If True, outputs a column
indicating whether the record pair would have been found by any of the
blocking rules in settings.blocking_rules_to_generate_predictions.
Defaults to False.

Returns:
SplinkDataFrame: Comparison results
"""
global _sql_cache

uid = ascii_uid(8)

if isinstance(record_1, dict):
to_register_left: AcceptableInputTableType = [record_1]
else:
to_register_left = record_1

if isinstance(record_2, dict):
to_register_right: AcceptableInputTableType = [record_2]
else:
to_register_right = record_2

df_records_left = db_api.register_table(
to_register_left,
f"__splink__compare_records_left_{uid}",
overwrite=True,
)
df_records_left.templated_name = "__splink__compare_records_left"

df_records_right = db_api.register_table(
to_register_right,
f"__splink__compare_records_right_{uid}",
overwrite=True,
)
df_records_right.templated_name = "__splink__compare_records_right"

settings_id = id(settings)
if use_sql_from_cache:
if cached_sql := _sql_cache.get(settings_id, uid):
return db_api._sql_to_splink_dataframe(
cached_sql,
templated_name="__splink__realtime_compare_records",
physical_name=f"__splink__realtime_compare_records_{uid}",
)

if not isinstance(settings, SettingsCreator):
settings_creator = SettingsCreator.from_path_or_dict(settings)
else:
settings_creator = settings

settings_obj = settings_creator.get_settings(db_api.sql_dialect.sql_dialect_str)

settings_obj._retain_matching_columns = True
settings_obj._retain_intermediate_calculation_columns = True

pipeline = CTEPipeline([df_records_left, df_records_right])

cols_to_select = settings_obj._columns_to_select_for_blocking

select_expr = ", ".join(cols_to_select)
sql = f"""
select {select_expr}, 0 as match_key
from __splink__compare_records_left as l
cross join __splink__compare_records_right as r
"""
pipeline.enqueue_sql(sql, "__splink__compare_two_records_blocked")

cols_to_select = settings_obj._columns_to_select_for_comparison_vector_values
select_expr = ", ".join(cols_to_select)
sql = f"""
select {select_expr}
from __splink__compare_two_records_blocked
"""
pipeline.enqueue_sql(sql, "__splink__df_comparison_vectors")

sqls = predict_from_comparison_vectors_sqls_using_settings(
settings_obj,
sql_infinity_expression=db_api.sql_dialect.infinity_expression,
)
pipeline.enqueue_list_of_sqls(sqls)

if include_found_by_blocking_rules:
br_col = _select_found_by_blocking_rules(settings_obj)
sql = f"""
select *, {br_col}
from __splink__df_predict
"""

pipeline.enqueue_sql(sql, "__splink__found_by_blocking_rules")

predictions = db_api.sql_pipeline_to_splink_dataframe(pipeline)
_sql_cache.set(settings_id, predictions.sql_used_to_create, uid)

return predictions
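
A hypothetical end-to-end usage sketch for the new module; the backend, model path and column names are assumptions rather than something specified in the PR:

```py
# Hypothetical usage sketch: assumes a DuckDB backend and a model previously
# saved to json from a trained linker. The path and columns are made up.
from splink import DuckDBAPI
from splink.internals.realtime import compare_records

db_api = DuckDBAPI()

record_1 = {"unique_id": 1, "first_name": "Lucas", "surname": "Smith", "dob": "1984-01-02"}
record_2 = {"unique_id": 2, "first_name": "Lucas", "surname": "Smyth", "dob": "1984-07-02"}

predictions = compare_records(
    record_1,
    record_2,
    settings="./saved_model.json",  # hypothetical path to a saved model
    db_api=db_api,
    include_found_by_blocking_rules=True,
)
predictions.as_pandas_dataframe()
```

Because the cache keys on the identity of the settings object, repeated calls that pass the same settings object (rather than re-reading a file or rebuilding a dict) reuse the generated SQL with only the uid swapped.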