diff --git a/splink/internals/accuracy.py b/splink/internals/accuracy.py
index b7d342b5fa..7ec766b5cf 100644
--- a/splink/internals/accuracy.py
+++ b/splink/internals/accuracy.py
@@ -20,6 +20,7 @@
 if TYPE_CHECKING:
     from splink.internals.linker import Linker
+    from splink.internals.settings import Settings
 
 
 def truth_space_table_from_labels_with_predictions_sqls(
@@ -289,8 +290,8 @@ def truth_space_table_from_labels_with_predictions_sqls(
     return sqls
 
 
-def _select_found_by_blocking_rules(linker: "Linker") -> str:
-    brs = linker._settings_obj._blocking_rules_to_generate_predictions
+def _select_found_by_blocking_rules(settings_obj: "Settings") -> str:
+    brs = settings_obj._blocking_rules_to_generate_predictions
 
     if brs:
         br_strings = [
@@ -425,7 +426,7 @@ def predictions_from_sample_of_pairwise_labels_sql(linker, labels_tablename):
     )
     sqls.extend(sqls_2)
 
-    br_col = _select_found_by_blocking_rules(linker)
+    br_col = _select_found_by_blocking_rules(linker._settings_obj)
 
     sql = f"""
     select *, {br_col}
diff --git a/splink/internals/linker_components/inference.py b/splink/internals/linker_components/inference.py
index d147714947..035dfcfe85 100644
--- a/splink/internals/linker_components/inference.py
+++ b/splink/internals/linker_components/inference.py
@@ -4,6 +4,7 @@
 import time
 from typing import TYPE_CHECKING, Any
 
+from splink.internals.accuracy import _select_found_by_blocking_rules
 from splink.internals.blocking import (
     BlockingRule,
     block_using_rules_sqls,
@@ -639,16 +640,31 @@ def find_matches_to_new_records(
         return predictions
 
     def compare_two_records(
-        self, record_1: dict[str, Any], record_2: dict[str, Any]
+        self,
+        record_1: dict[str, Any] | AcceptableInputTableType,
+        record_2: dict[str, Any] | AcceptableInputTableType,
+        include_found_by_blocking_rules: bool = False,
     ) -> SplinkDataFrame:
         """Use the linkage model to compare and score a pairwise record comparison
-        based on the two input records provided
+        based on the two input records provided.
+
+        If your inputs contain multiple rows, scores for the cartesian product of
+        the two inputs will be returned.
+
+        If your inputs contain hardcoded term frequency columns (e.g.
+        a tf_first_name column), then these values will be used instead of any
+        provided term frequency lookup tables or term frequency values derived
+        from the input data.
 
         Args:
             record_1 (dict): dictionary representing the first record. Columns
                 names and data types must be the same as the columns in the
                 settings object
             record_2 (dict): dictionary representing the second record. Columns
                 names and data types must be the same as the columns in the
                 settings object
+            include_found_by_blocking_rules (bool, optional): If True, outputs a column
+                indicating whether the record pair would have been found by any of the
+                blocking rules specified in
+                settings.blocking_rules_to_generate_predictions. Defaults to False.
Examples: ```py @@ -683,22 +699,31 @@ def compare_two_records( SplinkDataFrame: Pairwise comparison with scored prediction """ - cache = self._linker._intermediate_table_cache + linker = self._linker + + retain_matching_columns = linker._settings_obj._retain_matching_columns + retain_intermediate_calculation_columns = ( + linker._settings_obj._retain_intermediate_calculation_columns + ) + linker._settings_obj._retain_matching_columns = True + linker._settings_obj._retain_intermediate_calculation_columns = True + + cache = linker._intermediate_table_cache uid = ascii_uid(8) # Check if input is a DuckDB relation without importing DuckDB if isinstance(record_1, dict): - to_register_left = [record_1] + to_register_left: AcceptableInputTableType = [record_1] else: to_register_left = record_1 if isinstance(record_2, dict): - to_register_right = [record_2] + to_register_right: AcceptableInputTableType = [record_2] else: to_register_right = record_2 - df_records_left = self._linker.table_management.register_table( + df_records_left = linker.table_management.register_table( to_register_left, f"__splink__compare_two_records_left_{uid}", overwrite=True, @@ -706,7 +731,7 @@ def compare_two_records( df_records_left.templated_name = "__splink__compare_two_records_left" - df_records_right = self._linker.table_management.register_table( + df_records_right = linker.table_management.register_table( to_register_right, f"__splink__compare_two_records_right_{uid}", overwrite=True, @@ -719,7 +744,9 @@ def compare_two_records( nodes_with_tf = cache.get_with_logging("__splink__df_concat_with_tf") pipeline.append_input_dataframe(nodes_with_tf) - for tf_col in self._linker._settings_obj._term_frequency_columns: + tf_cols = linker._settings_obj._term_frequency_columns + + for tf_col in tf_cols: tf_table_name = colname_to_tf_tablename(tf_col) if tf_table_name in cache: tf_table = cache.get_with_logging(tf_table_name) @@ -734,24 +761,19 @@ def compare_two_records( ) sql_join_tf = _join_new_table_to_df_concat_with_tf_sql( - self._linker, "__splink__compare_two_records_left" + linker, "__splink__compare_two_records_left", df_records_left ) pipeline.enqueue_sql(sql_join_tf, "__splink__compare_two_records_left_with_tf") sql_join_tf = _join_new_table_to_df_concat_with_tf_sql( - self._linker, "__splink__compare_two_records_right" + linker, "__splink__compare_two_records_right", df_records_right ) pipeline.enqueue_sql(sql_join_tf, "__splink__compare_two_records_right_with_tf") - source_dataset_ic = ( - self._linker._settings_obj.column_info_settings.source_dataset_input_column - ) - uid_ic = self._linker._settings_obj.column_info_settings.unique_id_input_column - pipeline = add_unique_id_and_source_dataset_cols_if_needed( - self._linker, + linker, df_records_left, pipeline, in_tablename="__splink__compare_two_records_left_with_tf", @@ -759,7 +781,7 @@ def compare_two_records( uid_str="_left", ) pipeline = add_unique_id_and_source_dataset_cols_if_needed( - self._linker, + linker, df_records_right, pipeline, in_tablename="__splink__compare_two_records_right_with_tf", @@ -767,34 +789,48 @@ def compare_two_records( uid_str="_right", ) - sqls = block_using_rules_sqls( - input_tablename_l="__splink__compare_two_records_left_with_tf_uid_fix", - input_tablename_r="__splink__compare_two_records_right_with_tf_uid_fix", - blocking_rules=[BlockingRule("1=1")], - link_type=self._linker._settings_obj._link_type, - source_dataset_input_column=source_dataset_ic, - unique_id_input_column=uid_ic, - ) - pipeline.enqueue_list_of_sqls(sqls) + 
cols_to_select = self._linker._settings_obj._columns_to_select_for_blocking - sqls = compute_comparison_vector_values_from_id_pairs_sqls( - self._linker._settings_obj._columns_to_select_for_blocking, - self._linker._settings_obj._columns_to_select_for_comparison_vector_values, - input_tablename_l="__splink__compare_two_records_left_with_tf_uid_fix", - input_tablename_r="__splink__compare_two_records_right_with_tf_uid_fix", - source_dataset_input_column=source_dataset_ic, - unique_id_input_column=uid_ic, + select_expr = ", ".join(cols_to_select) + sql = f""" + select {select_expr}, 0 as match_key + from __splink__compare_two_records_left_with_tf_uid_fix as l + cross join __splink__compare_two_records_right_with_tf_uid_fix as r + """ + pipeline.enqueue_sql(sql, "__splink__compare_two_records_blocked") + + cols_to_select = ( + linker._settings_obj._columns_to_select_for_comparison_vector_values ) - pipeline.enqueue_list_of_sqls(sqls) + select_expr = ", ".join(cols_to_select) + sql = f""" + select {select_expr} + from __splink__compare_two_records_blocked + """ + pipeline.enqueue_sql(sql, "__splink__df_comparison_vectors") sqls = predict_from_comparison_vectors_sqls_using_settings( - self._linker._settings_obj, - sql_infinity_expression=self._linker._infinity_expression, + linker._settings_obj, + sql_infinity_expression=linker._infinity_expression, ) pipeline.enqueue_list_of_sqls(sqls) - predictions = self._linker._db_api.sql_pipeline_to_splink_dataframe( + if include_found_by_blocking_rules: + br_col = _select_found_by_blocking_rules(linker._settings_obj) + sql = f""" + select *, {br_col} + from __splink__df_predict + """ + + pipeline.enqueue_sql(sql, "__splink__found_by_blocking_rules") + + predictions = linker._db_api.sql_pipeline_to_splink_dataframe( pipeline, use_cache=False ) + linker._settings_obj._retain_matching_columns = retain_matching_columns + linker._settings_obj._retain_intermediate_calculation_columns = ( + retain_intermediate_calculation_columns + ) + return predictions diff --git a/splink/internals/realtime.py b/splink/internals/realtime.py new file mode 100644 index 0000000000..168b838759 --- /dev/null +++ b/splink/internals/realtime.py @@ -0,0 +1,142 @@ +from __future__ import annotations + +from pathlib import Path +from typing import Any + +from splink.internals.accuracy import _select_found_by_blocking_rules +from splink.internals.database_api import AcceptableInputTableType, DatabaseAPISubClass +from splink.internals.misc import ascii_uid +from splink.internals.pipeline import CTEPipeline +from splink.internals.predict import ( + predict_from_comparison_vectors_sqls_using_settings, +) +from splink.internals.settings_creator import SettingsCreator +from splink.internals.splink_dataframe import SplinkDataFrame + + +class SQLCache: + def __init__(self): + self._cache: dict[int, tuple[str, str | None]] = {} + + def get(self, settings_id: int, new_uid: str) -> str | None: + if settings_id not in self._cache: + return None + + sql, cached_uid = self._cache[settings_id] + if cached_uid: + sql = sql.replace(cached_uid, new_uid) + return sql + + def set(self, settings_id: int, sql: str | None, uid: str | None) -> None: + if sql is not None: + self._cache[settings_id] = (sql, uid) + + +_sql_cache = SQLCache() + + +def compare_records( + record_1: dict[str, Any] | AcceptableInputTableType, + record_2: dict[str, Any] | AcceptableInputTableType, + settings: SettingsCreator | dict[str, Any] | Path | str, + db_api: DatabaseAPISubClass, + use_sql_from_cache: bool = True, + 
include_found_by_blocking_rules: bool = False, +) -> SplinkDataFrame: + """Compare two records and compute similarity scores without requiring a Linker. + Assumes any required term frequency values are provided in the input records. + + Args: + record_1 (dict): First record to compare + record_2 (dict): Second record to compare + db_api (DatabaseAPISubClass): Database API to use for computations + + Returns: + SplinkDataFrame: Comparison results + """ + global _sql_cache + + uid = ascii_uid(8) + + if isinstance(record_1, dict): + to_register_left: AcceptableInputTableType = [record_1] + else: + to_register_left = record_1 + + if isinstance(record_2, dict): + to_register_right: AcceptableInputTableType = [record_2] + else: + to_register_right = record_2 + + df_records_left = db_api.register_table( + to_register_left, + f"__splink__compare_records_left_{uid}", + overwrite=True, + ) + df_records_left.templated_name = "__splink__compare_records_left" + + df_records_right = db_api.register_table( + to_register_right, + f"__splink__compare_records_right_{uid}", + overwrite=True, + ) + df_records_right.templated_name = "__splink__compare_records_right" + + settings_id = id(settings) + if use_sql_from_cache: + if cached_sql := _sql_cache.get(settings_id, uid): + return db_api._sql_to_splink_dataframe( + cached_sql, + templated_name="__splink__realtime_compare_records", + physical_name=f"__splink__realtime_compare_records_{uid}", + ) + + if not isinstance(settings, SettingsCreator): + settings_creator = SettingsCreator.from_path_or_dict(settings) + else: + settings_creator = settings + + settings_obj = settings_creator.get_settings(db_api.sql_dialect.sql_dialect_str) + + settings_obj._retain_matching_columns = True + settings_obj._retain_intermediate_calculation_columns = True + + pipeline = CTEPipeline([df_records_left, df_records_right]) + + cols_to_select = settings_obj._columns_to_select_for_blocking + + select_expr = ", ".join(cols_to_select) + sql = f""" + select {select_expr}, 0 as match_key + from __splink__compare_records_left as l + cross join __splink__compare_records_right as r + """ + pipeline.enqueue_sql(sql, "__splink__compare_two_records_blocked") + + cols_to_select = settings_obj._columns_to_select_for_comparison_vector_values + select_expr = ", ".join(cols_to_select) + sql = f""" + select {select_expr} + from __splink__compare_two_records_blocked + """ + pipeline.enqueue_sql(sql, "__splink__df_comparison_vectors") + + sqls = predict_from_comparison_vectors_sqls_using_settings( + settings_obj, + sql_infinity_expression=db_api.sql_dialect.infinity_expression, + ) + pipeline.enqueue_list_of_sqls(sqls) + + if include_found_by_blocking_rules: + br_col = _select_found_by_blocking_rules(settings_obj) + sql = f""" + select *, {br_col} + from __splink__df_predict + """ + + pipeline.enqueue_sql(sql, "__splink__found_by_blocking_rules") + + predictions = db_api.sql_pipeline_to_splink_dataframe(pipeline) + _sql_cache.set(settings_id, predictions.sql_used_to_create, uid) + + return predictions diff --git a/splink/internals/term_frequencies.py b/splink/internals/term_frequencies.py index 250873e1d4..5cb03d9dbc 100644 --- a/splink/internals/term_frequencies.py +++ b/splink/internals/term_frequencies.py @@ -4,7 +4,7 @@ # https://github.com/moj-analytical-services/splink/pull/107 import logging import warnings -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Optional from numpy import arange, ceil, floor, log2 from pandas import concat, cut @@ -16,6 +16,7 @@ ) from 
splink.internals.input_column import InputColumn from splink.internals.pipeline import CTEPipeline +from splink.internals.splink_dataframe import SplinkDataFrame # https://stackoverflow.com/questions/39740632/python-type-hinting-without-cyclic-imports if TYPE_CHECKING: @@ -79,33 +80,48 @@ def _join_tf_to_df_concat_sql(linker: Linker) -> str: return sql -def _join_new_table_to_df_concat_with_tf_sql(linker: Linker, new_tablename: str) -> str: +def _join_new_table_to_df_concat_with_tf_sql( + linker: Linker, + input_tablename: str, + input_table: Optional[SplinkDataFrame] = None, +) -> str: """ - Joins any required tf columns onto new_tablename + Joins any required tf columns onto input_tablename This is needed e.g. when using linker.compare_two_records or linker.inference.find_matches_to_new_records in which the user provides new records which need tf adjustments computed """ + tf_cols_already_populated = [] + + if input_table is not None: + tf_cols_already_populated = [ + c.unquote().name + for c in input_table.columns + if c.unquote().name.startswith("tf_") + ] + tf_cols_not_already_populated = [ + c + for c in linker._settings_obj._term_frequency_columns + if c.unquote().tf_name not in tf_cols_already_populated + ] cache = linker._intermediate_table_cache - settings_obj = linker._settings_obj - tf_cols = settings_obj._term_frequency_columns - select_cols = [f"{new_tablename}.*"] + select_cols = [f"{input_tablename}.*"] - for col in tf_cols: + for col in tf_cols_not_already_populated: tbl = colname_to_tf_tablename(col) if tbl in cache: select_cols.append(f"{tbl}.{col.tf_name}") - template = "left join {tbl} on " + new_tablename + ".{col} = {tbl}.{col}" + template = "left join {tbl} on " + input_tablename + ".{col} = {tbl}.{col}" template_with_alias = ( - "left join ({subquery}) as {_as} on " + new_tablename + ".{col} = {_as}.{col}" + "left join ({subquery}) as {_as} on " + input_tablename + ".{col} = {_as}.{col}" ) left_joins = [] - for i, col in enumerate(tf_cols): + for i, col in enumerate(tf_cols_not_already_populated): tbl = colname_to_tf_tablename(col) if tbl in cache: sql = template.format(tbl=tbl, col=col.name) @@ -127,7 +143,7 @@ def _join_new_table_to_df_concat_with_tf_sql(linker: Linker, new_tablename: str) sql = f""" select {select_cols_str} - from {new_tablename} + from {input_tablename} {left_joins_str} """ diff --git a/tests/datasets/fake_1000_from_splink_demos_strip_datetypes.parquet b/tests/datasets/fake_1000_from_splink_demos_strip_datetypes.parquet new file mode 100644 index 0000000000..84563de25f Binary files /dev/null and b/tests/datasets/fake_1000_from_splink_demos_strip_datetypes.parquet differ diff --git a/tests/test_compare_two_records.py b/tests/test_compare_two_records.py new file mode 100644 index 0000000000..3b0ee04de5 --- /dev/null +++ b/tests/test_compare_two_records.py @@ -0,0 +1,167 @@ +import datetime + +import numpy as np +import pandas as pd +import pytest + +import splink.internals.comparison_library as cl +from splink import SettingsCreator +from splink.internals.blocking_rule_library import block_on +from splink.internals.pipeline import CTEPipeline +from splink.internals.vertically_concatenate import compute_df_concat_with_tf + +from .decorator import mark_with_dialects_excluding + + +@mark_with_dialects_excluding("sqlite") +def test_compare_two_records_1(test_helpers, dialect): + # This one tests the following cases + # - User provides a city tf tble + # - But first_name tf table derived from input data + helper = test_helpers[dialect] + Linker = 
helper.Linker + + df = helper.load_frame_from_parquet( + "./tests/datasets/fake_1000_from_splink_demos_strip_datetypes.parquet" + ) + + settings = SettingsCreator( + link_type="dedupe_only", + comparisons=[ + cl.ExactMatch("first_name").configure(term_frequency_adjustments=True), + cl.ExactMatch("surname"), + cl.DateOfBirthComparison("dob", input_is_string=False), + cl.ExactMatch("city").configure(term_frequency_adjustments=True), + cl.ExactMatch("email"), + ], + blocking_rules_to_generate_predictions=[ + block_on("first_name"), + block_on("surname"), + ], + max_iterations=2, + retain_intermediate_calculation_columns=True, + retain_matching_columns=True, + ) + + linker = Linker(df, settings, **helper.extra_linker_args()) + + city_tf = pd.DataFrame( + [ + {"city": "London", "tf_city": 0.2}, + {"city": "Liverpool", "tf_city": 0.8}, + ] + ) + linker.table_management.register_term_frequency_lookup(city_tf, "city") + + # Compute the df_concat_with_tf so it's cached + pipeline = CTEPipeline() + compute_df_concat_with_tf(linker, pipeline) + + # Test with dictionary inputs + r1 = { + "first_name": "Julia", + "surname": "Taylor", + "dob": datetime.date(2015, 10, 29), + "city": "London", + "email": "hannah88@powers.com", + } + + r2 = { + "first_name": "Julia", + "surname": "Taylor", + "dob": datetime.date(2015, 10, 29), + "city": "London", + "email": "hannah88@powers.com", + } + + res = linker.inference.compare_two_records(r1, r2) + res_pd = res.as_pandas_dataframe() + + # Verify term frequencies match in the comparison result + assert res_pd["tf_city_l"].iloc[0] == 0.2 + assert res_pd["tf_city_r"].iloc[0] == 0.2 + # This is the tf value as derived from the input data + assert pytest.approx(res_pd["tf_first_name_l"].iloc[0]) == np.float64( + 0.00444444444444 + ) + assert pytest.approx(res_pd["tf_first_name_r"].iloc[0]) == np.float64( + 0.00444444444444 + ) + + +@mark_with_dialects_excluding("sqlite") +def test_compare_two_records_2(test_helpers, dialect): + # This one tests the following cases + # - User provides a city and first_name tf tables + # - But specific values provided in input data, which take precedence + + helper = test_helpers[dialect] + Linker = helper.Linker + + df = helper.load_frame_from_parquet( + "./tests/datasets/fake_1000_from_splink_demos_strip_datetypes.parquet" + ) + + settings = SettingsCreator( + link_type="dedupe_only", + comparisons=[ + cl.ExactMatch("first_name").configure(term_frequency_adjustments=True), + cl.ExactMatch("surname"), + cl.DateOfBirthComparison("dob", input_is_string=False), + cl.ExactMatch("city").configure(term_frequency_adjustments=True), + cl.ExactMatch("email"), + ], + blocking_rules_to_generate_predictions=[ + block_on("first_name"), + block_on("surname"), + ], + max_iterations=2, + retain_intermediate_calculation_columns=True, + retain_matching_columns=True, + ) + + linker = Linker(df, settings, **helper.extra_linker_args()) + + city_tf = pd.DataFrame( + [ + {"city": "London", "tf_city": 0.2}, + {"city": "Liverpool", "tf_city": 0.8}, + ] + ) + linker.table_management.register_term_frequency_lookup(city_tf, "city") + + first_name_tf = pd.DataFrame( + [ + {"first_name": "Julia", "tf_first_name": 0.3}, + {"first_name": "Robert", "tf_first_name": 0.8}, + ] + ) + linker.table_management.register_term_frequency_lookup(first_name_tf, "first_name") + + # Test with dictionary inputs + r1 = { + "first_name": "Julia", + "surname": "Taylor", + "dob": datetime.date(2015, 10, 29), + "city": "London", + "email": "hannah88@powers.com", + "tf_city": 0.5, + } + + 
r2 = { + "first_name": "Julia", + "surname": "Taylor", + "dob": datetime.date(2015, 10, 29), + "city": "London", + "email": "hannah88@powers.com", + "tf_first_name": 0.4, + } + + res = linker.inference.compare_two_records(r1, r2) + res_pd = res.as_pandas_dataframe() + + # Verify term frequencies match in the comparison result + assert res_pd["tf_city_l"].iloc[0] == 0.5 + assert res_pd["tf_city_r"].iloc[0] == 0.2 + assert res_pd["tf_first_name_l"].iloc[0] == 0.3 + assert res_pd["tf_first_name_r"].iloc[0] == 0.4 diff --git a/tests/test_realtime.py b/tests/test_realtime.py new file mode 100644 index 0000000000..c735c0b84a --- /dev/null +++ b/tests/test_realtime.py @@ -0,0 +1,356 @@ +from __future__ import annotations + +import pandas as pd +import pytest + +import splink.comparison_library as cl +from splink import SettingsCreator, block_on +from splink.internals.realtime import compare_records + +from .decorator import mark_with_dialects_excluding + + +@mark_with_dialects_excluding() +def test_realtime_cache_two_records(test_helpers, dialect): + # Test that you get the same result whether you cache the SQL + # or not with different records + + helper = test_helpers[dialect] + + db_api = helper.extra_linker_args()["db_api"] + + df1 = pd.DataFrame( + [ + { + "unique_id": 0, + "first_name": "Julia ", + "surname": "Taylor", + "city": "London", + "email": "hannah88@powers.com", + "tf_city": 0.2, + "tf_first_name": 0.1, + } + ] + ) + + df2 = pd.DataFrame( + [ + { + "unique_id": 2, + "first_name": "Julia ", + "surname": "Taylor", + "city": "London", + "email": "hannah88@powers.com", + "cluster": 0, + "tf_city": 0.2, + "tf_first_name": 0.1, + }, + ] + ) + + df3 = pd.DataFrame( + [ + { + "unique_id": 4, + "first_name": "Noah", + "surname": "Watson", + "city": "Bolton", + "email": "matthew78@ballard-mcdonald.net", + "cluster": 1, + "tf_city": 0.01, + "tf_first_name": 0.01, + }, + ] + ) + + settings = SettingsCreator( + link_type="dedupe_only", + comparisons=[ + cl.ExactMatch("first_name").configure(term_frequency_adjustments=True), + cl.ExactMatch("surname"), + cl.ExactMatch("city").configure(term_frequency_adjustments=True), + cl.ExactMatch("email"), + ], + blocking_rules_to_generate_predictions=[ + block_on("first_name"), + block_on("surname"), + ], + max_iterations=2, + retain_intermediate_calculation_columns=True, + retain_matching_columns=True, + ) + + res1_2_first = compare_records(df1, df2, settings, db_api).as_record_dict()[0][ + "match_weight" + ] + + res1_2_not_from_cache = compare_records( + df1, df2, settings, db_api, use_sql_from_cache=False + ).as_record_dict()[0]["match_weight"] + + res1_2_from_cache = compare_records( + df1, df2, settings, db_api, use_sql_from_cache=True + ).as_record_dict()[0]["match_weight"] + + assert res1_2_first == pytest.approx(res1_2_not_from_cache) + assert res1_2_first == pytest.approx(res1_2_from_cache) + + res1_3_first = compare_records(df1, df3, settings, db_api).as_record_dict()[0][ + "match_weight" + ] + res1_3_not_from_cache = compare_records( + df1, df3, settings, db_api, use_sql_from_cache=False + ).as_record_dict()[0]["match_weight"] + res1_3_from_cache = compare_records( + df1, df3, settings, db_api, use_sql_from_cache=True + ).as_record_dict()[0]["match_weight"] + + assert res1_3_first == pytest.approx(res1_3_not_from_cache) + assert res1_3_first == pytest.approx(res1_3_from_cache) + + assert res1_2_first != pytest.approx(res1_3_first) + + +@mark_with_dialects_excluding() +def test_realtime_cache_multiple_records(test_helpers, dialect): + # Test 
that you get the same result whether you cache the SQL + # or not with multiple records in each DataFrame + + helper = test_helpers[dialect] + db_api = helper.extra_linker_args()["db_api"] + + df1 = pd.DataFrame( + [ + { + "unique_id": 0, + "first_name": "Julia", + "surname": "Taylor", + "city": "London", + "email": "hannah88@powers.com", + "tf_city": 0.2, + "tf_first_name": 0.1, + }, + { + "unique_id": 1, + "first_name": "John", + "surname": "Smith", + "city": "Manchester", + "email": "john.smith@email.com", + "tf_city": 0.2, + "tf_first_name": 0.1, + }, + ] + ) + + df2 = pd.DataFrame( + [ + { + "unique_id": 2, + "first_name": "Julia", + "surname": "Taylor", + "city": "London", + "email": "hannah88@powers.com", + "cluster": 0, + "tf_city": 0.2, + "tf_first_name": 0.1, + }, + { + "unique_id": 3, + "first_name": "Jane", + "surname": "Wilson", + "city": "Birmingham", + "email": "jane.w@example.com", + "cluster": 1, + "tf_city": 0.2, + "tf_first_name": 0.1, + }, + ] + ) + + df3 = pd.DataFrame( + [ + { + "unique_id": 4, + "first_name": "Noah", + "surname": "Watson", + "city": "Bolton", + "email": "matthew78@ballard-mcdonald.net", + "cluster": 2, + "tf_city": 0.2, + "tf_first_name": 0.1, + }, + { + "unique_id": 5, + "first_name": "Emma", + "surname": "Brown", + "city": "Leeds", + "email": "emma.b@test.com", + "cluster": 3, + "tf_city": 0.2, + "tf_first_name": 0.1, + }, + { + "unique_id": 6, + "first_name": "Oliver", + "surname": "Davies", + "city": "Bristol", + "email": "oliver.d@example.net", + "cluster": 4, + "tf_city": 0.2, + "tf_first_name": 0.1, + }, + ] + ) + + # Add required columns if they don't exist + for frame in [df1, df2, df3]: + if "tf_city" not in frame.columns: + frame["tf_city"] = 0.2 + if "tf_first_name" not in frame.columns: + frame["tf_first_name"] = 0.1 + if "cluster" not in frame.columns and frame is not df1: + frame["cluster"] = range(len(frame)) + + settings = SettingsCreator( + link_type="dedupe_only", + comparisons=[ + cl.ExactMatch("first_name").configure(term_frequency_adjustments=True), + cl.ExactMatch("surname"), + cl.ExactMatch("city").configure(term_frequency_adjustments=True), + cl.ExactMatch("email"), + ], + blocking_rules_to_generate_predictions=[ + block_on("first_name"), + block_on("surname"), + ], + max_iterations=2, + retain_intermediate_calculation_columns=True, + retain_matching_columns=True, + ) + + # Compare df1 and df2 + res1_2_first = compare_records(df1, df2, settings, db_api).as_pandas_dataframe() + res1_2_not_from_cache = compare_records( + df1, df2, settings, db_api, use_sql_from_cache=False + ).as_pandas_dataframe() + res1_2_from_cache = compare_records( + df1, df2, settings, db_api, use_sql_from_cache=True + ).as_pandas_dataframe() + + # Compare match weights using pandas merge + merged = res1_2_first.merge( + res1_2_not_from_cache, + on=["unique_id_l", "unique_id_r"], + suffixes=("_first", "_not_cache"), + ) + pd.testing.assert_series_equal( + merged["match_weight_first"], + merged["match_weight_not_cache"], + check_names=False, + ) + + merged = res1_2_first.merge( + res1_2_from_cache, + on=["unique_id_l", "unique_id_r"], + suffixes=("_first", "_from_cache"), + ) + pd.testing.assert_series_equal( + merged["match_weight_first"], + merged["match_weight_from_cache"], + check_names=False, + ) + + res1_3_first = compare_records(df1, df3, settings, db_api).as_pandas_dataframe() + res1_3_not_from_cache = compare_records( + df1, df3, settings, db_api, use_sql_from_cache=False + ).as_pandas_dataframe() + res1_3_from_cache = compare_records( + df1, df3, 
settings, db_api, use_sql_from_cache=True + ).as_pandas_dataframe() + + merged = res1_3_first.merge( + res1_3_not_from_cache, + on=["unique_id_l", "unique_id_r"], + suffixes=("_first", "_not_cache"), + ) + pd.testing.assert_series_equal( + merged["match_weight_first"], + merged["match_weight_not_cache"], + check_names=False, + ) + + merged = res1_3_first.merge( + res1_3_from_cache, + on=["unique_id_l", "unique_id_r"], + suffixes=("_first", "_from_cache"), + ) + pd.testing.assert_series_equal( + merged["match_weight_first"], + merged["match_weight_from_cache"], + check_names=False, + ) + + +@mark_with_dialects_excluding() +def test_realtime_cache_different_settings(test_helpers, dialect): + helper = test_helpers[dialect] + db_api = helper.extra_linker_args()["db_api"] + + df1 = pd.DataFrame( + [ + { + "unique_id": 0, + "first_name": "Julia", + "surname": "Taylor", + "city": "London", + "email": "julia@email.com", + } + ] + ) + + df2 = pd.DataFrame( + [ + { + "unique_id": 1, + "first_name": "Julia", + "surname": "Taylor", + "city": "London", + "email": "bad@address.com", + } + ] + ) + + settings_1 = SettingsCreator( + link_type="dedupe_only", + comparisons=[ + cl.ExactMatch("first_name"), + cl.ExactMatch("surname"), + cl.ExactMatch("city"), + ], + blocking_rules_to_generate_predictions=[block_on("first_name")], + ) + + settings_2 = SettingsCreator( + link_type="dedupe_only", + comparisons=[ + cl.ExactMatch("first_name"), + cl.ExactMatch("surname"), + cl.ExactMatch("email"), + ], + blocking_rules_to_generate_predictions=[block_on("first_name")], + ) + + res1 = compare_records( + df1, df2, settings_1, db_api, use_sql_from_cache=True + ).as_record_dict()[0]["match_weight"] + + res2 = compare_records( + df1, df2, settings_2, db_api, use_sql_from_cache=True + ).as_record_dict()[0]["match_weight"] + + assert res1 != pytest.approx(res2) + + res1_again = compare_records( + df1, df2, settings_1, db_api, use_sql_from_cache=True + ).as_record_dict()[0]["match_weight"] + assert res1 == pytest.approx(res1_again)
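Taken together, the `inference.py` changes above can be exercised roughly as follows. This is a minimal sketch, not part of the patch: it assumes the DuckDB backend and the bundled `splink_datasets.fake_1000` demo data, registers a `city` term frequency lookup, and warms the `__splink__df_concat_with_tf` cache the same way the new tests do before calling `compare_two_records` with the new `include_found_by_blocking_rules` flag.

```py
import pandas as pd

import splink.comparison_library as cl
from splink import DuckDBAPI, Linker, SettingsCreator, block_on, splink_datasets
from splink.internals.pipeline import CTEPipeline
from splink.internals.vertically_concatenate import compute_df_concat_with_tf

db_api = DuckDBAPI()
df = splink_datasets.fake_1000

settings = SettingsCreator(
    link_type="dedupe_only",
    comparisons=[
        cl.ExactMatch("first_name").configure(term_frequency_adjustments=True),
        cl.ExactMatch("surname"),
        cl.ExactMatch("city").configure(term_frequency_adjustments=True),
        cl.ExactMatch("email"),
    ],
    blocking_rules_to_generate_predictions=[
        block_on("first_name"),
        block_on("surname"),
    ],
    retain_intermediate_calculation_columns=True,
    retain_matching_columns=True,
)

linker = Linker(df, settings, db_api=db_api)

# A user-supplied tf lookup for city; first_name tf values are derived from df
city_tf = pd.DataFrame(
    [
        {"city": "London", "tf_city": 0.2},
        {"city": "Liverpool", "tf_city": 0.8},
    ]
)
linker.table_management.register_term_frequency_lookup(city_tf, "city")

# Warm the cache of concatenated input data with tf columns, as the new tests do
pipeline = CTEPipeline()
compute_df_concat_with_tf(linker, pipeline)

record_1 = {
    "first_name": "Julia",
    "surname": "Taylor",
    "city": "London",
    "email": "hannah88@powers.com",
    "tf_city": 0.5,  # hardcoded tf value: takes precedence over the lookup
}
record_2 = {
    "first_name": "Julia",
    "surname": "Taylor",
    "city": "London",
    "email": "hannah88@powers.com",
}

res = linker.inference.compare_two_records(
    record_1, record_2, include_found_by_blocking_rules=True
)
res_pd = res.as_pandas_dataframe()

# Left record carried its own tf_city, so the registered lookup is bypassed;
# the right record has no tf_city column, so the lookup value is used instead
assert res_pd["tf_city_l"].iloc[0] == 0.5
assert res_pd["tf_city_r"].iloc[0] == 0.2
```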
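The new `splink.internals.realtime.compare_records` function covers the same scoring path without constructing a `Linker`. A sketch under the same assumptions (DuckDB backend, illustrative records): term frequency values are hardcoded on the records as `tf_*` columns, and the second call reuses the SQL cached against `id(settings)` rather than regenerating it.

```py
import splink.comparison_library as cl
from splink import DuckDBAPI, SettingsCreator, block_on
from splink.internals.realtime import compare_records

db_api = DuckDBAPI()

settings = SettingsCreator(
    link_type="dedupe_only",
    comparisons=[
        cl.ExactMatch("first_name").configure(term_frequency_adjustments=True),
        cl.ExactMatch("surname"),
        cl.ExactMatch("city").configure(term_frequency_adjustments=True),
        cl.ExactMatch("email"),
    ],
    blocking_rules_to_generate_predictions=[
        block_on("first_name"),
        block_on("surname"),
    ],
)

record_1 = {
    "unique_id": 0,
    "first_name": "Julia",
    "surname": "Taylor",
    "city": "London",
    "email": "hannah88@powers.com",
    "tf_city": 0.2,
    "tf_first_name": 0.1,
}
record_2 = {
    "unique_id": 1,
    "first_name": "Julia",
    "surname": "Tailor",
    "city": "London",
    "email": "hannah88@powers.com",
    "tf_city": 0.2,
    "tf_first_name": 0.1,
}

# First call generates the full SQL pipeline and stores the SQL it ran
first = compare_records(record_1, record_2, settings, db_api)

# Subsequent calls with the same settings object skip SQL generation entirely,
# re-running the cached SQL against the newly registered input tables
again = compare_records(record_1, record_2, settings, db_api, use_sql_from_cache=True)

print(first.as_record_dict()[0]["match_weight"])
print(again.as_record_dict()[0]["match_weight"])  # same score, from cached SQL
```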
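The caching mechanism itself is simple enough to illustrate in isolation: the stored SQL embeds the random table uid of the call that generated it, and `SQLCache.get` rewrites that uid for each new caller. A toy example using the class defined in `realtime.py` (values are illustrative):

```py
from splink.internals.realtime import SQLCache

cache = SQLCache()

# Pretend this SQL was generated by a call whose registered tables used uid "abc123"
cache.set(
    settings_id=123,
    sql="select * from __splink__compare_records_left_abc123",
    uid="abc123",
)

# A later call with the same settings id but a fresh uid gets the same SQL,
# re-pointed at its own freshly registered input tables
print(cache.get(123, "zzz999"))
# select * from __splink__compare_records_left_zzz999
```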