diff --git a/splink/athena/athena_helpers/athena_blocking_rule_imports.py b/splink/athena/athena_helpers/athena_blocking_rule_imports.py deleted file mode 100644 index 7c76a32db1..0000000000 --- a/splink/athena/athena_helpers/athena_blocking_rule_imports.py +++ /dev/null @@ -1,27 +0,0 @@ -from __future__ import annotations - -from functools import partial - -from ...blocking_rules_library import ( - BlockingRule, - exact_match_rule, -) -from ...blocking_rules_library import ( - block_on as _block_on_, -) - -exact_match_rule = partial(exact_match_rule, _sql_dialect="presto") - - -def block_on( - col_names: list[str], - salting_partitions: int = 1, -) -> BlockingRule: - return _block_on_( - exact_match_rule, - col_names, - salting_partitions, - ) - - -block_on.__doc__ = _block_on_.__doc__ diff --git a/splink/athena/blocking_rule_library.py b/splink/athena/blocking_rule_library.py deleted file mode 100644 index 04cb234f8b..0000000000 --- a/splink/athena/blocking_rule_library.py +++ /dev/null @@ -1,9 +0,0 @@ -from ..blocking_rule_composition import ( # noqa: F401 - and_, - not_, - or_, -) -from .athena_helpers.athena_blocking_rule_imports import ( # noqa: F401 - block_on, - exact_match_rule, -) diff --git a/splink/blocking.py b/splink/blocking.py index b1468067c8..7218e89d14 100644 --- a/splink/blocking.py +++ b/splink/blocking.py @@ -19,7 +19,7 @@ from .linker import Linker -def blocking_rule_to_obj(br): +def blocking_rule_to_obj(br) -> BlockingRule: if isinstance(br, BlockingRule): return br elif isinstance(br, dict): diff --git a/splink/blocking_rule_composition.py b/splink/blocking_rule_composition.py deleted file mode 100644 index 73f69e9664..0000000000 --- a/splink/blocking_rule_composition.py +++ /dev/null @@ -1,352 +0,0 @@ -from __future__ import annotations - -import warnings - -from .blocking import BlockingRule, blocking_rule_to_obj -from .comparison_level_composition import _unify_sql_dialects - - -def and_( - *brls: BlockingRule | dict | str, - salting_partitions=1, -) -> BlockingRule: - """Merge BlockingRules using logical "AND". - - Merge multiple BlockingRules into a single BlockingRule by - merging their SQL conditions using a logical "AND". - - - Args: - *brls (BlockingRule | dict | str): BlockingRules or - blocking rules in the string/dictionary format. - salting_partitions (optional, int): Whether to add salting - to the blocking rule. More information on salting can - be found within the docs. Salting is only valid for Spark. - - Examples: - === ":simple-duckdb: DuckDB" - Simple exact rule composition with an `AND` clause - ``` python - import splink.duckdb.blocking_rule_library as brl - brl.and_( - brl.exact_match_rule("first_name"), - brl.exact_match_rule("surname") - ) - ``` - Composing a custom rule with an exact match on name and the year - from a date of birth column - ``` python - import splink.duckdb.blocking_rule_library as brl - brl.and_( - brl.exact_match_rule("first_name"), - "substr(l.dob,1,4) = substr(r.dob,1,4)" - ) - ``` - === ":simple-apachespark: Spark" - Simple exact rule composition with an `AND` clause - ``` python - import splink.spark.blocking_rule_library as brl - brl.and_( - brl.exact_match_rule("first_name"), - brl.exact_match_rule("surname") - ) - ``` - Composing a custom rule with an exact match on name and the year - from a date of birth column, with additional salting (spark exclusive) - ``` python - import splink.spark.blocking_rule_library as brl - brl.and_( - brl.exact_match_rule("first_name"), - "substr(l.dob,1,4) = substr(r.dob,1,4)", - salting_partitions=5 - ) - ``` - === ":simple-amazonaws: Athena" - Simple exact rule composition with an `AND` clause - ``` python - import splink.athena.blocking_rule_library as brl - brl.and_( - brl.exact_match_rule("first_name"), - brl.exact_match_rule("surname") - ) - ``` - Composing a custom rule with an exact match on name and the year - from a date of birth column - ``` python - import splink.athena.blocking_rule_library as brl - brl.and_( - brl.exact_match_rule("first_name"), - "substr(l.dob,1,4) = substr(r.dob,1,4)", - ) - ``` - === ":simple-sqlite: SQLite" - Simple exact rule composition with an `AND` clause - ``` python - import splink.sqlite.blocking_rule_library as brl - brl.and_( - brl.exact_match_rule("first_name"), - brl.exact_match_rule("surname") - ) - ``` - Composing a custom rule with an exact match on name and the year - from a date of birth column - ``` python - import splink.sqlite.blocking_rule_library as brl - brl.and_( - brl.exact_match_rule("first_name"), - "substr(l.dob,1,4) = substr(r.dob,1,4)", - ) - ``` - === "PostgreSQL" - Simple exact rule composition with an `OR` clause - ``` python - import splink.postgres.blocking_rule_library as brl - brl.and_( - brl.exact_match_rule("first_name"), - brl.exact_match_rule("surname") - ) - ``` - Composing a custom rule with an exact match on name and the year - from a date of birth column - ``` python - import splink.postgres.blocking_rule_library as brl - brl.and_( - brl.exact_match_rule("first_name"), - "substr(l.dob,1,4) = substr(r.dob,1,4)", - ) - ``` - - Returns: - BlockingRule: A new BlockingRule with the merged - SQL condition - """ - return _br_merge( - *brls, - clause="AND", - salting_partitions=salting_partitions, - ) - - -def or_( - *brls: BlockingRule | dict | str, - salting_partitions: int = 1, -) -> BlockingRule: - """Merge BlockingRules using logical "OR". - - Merge multiple BlockingRules into a single BlockingRule by - merging their SQL conditions using a logical "OR". - - - Args: - *brls (BlockingRule | dict | str): BlockingRules or - blocking rules in the string/dictionary format. - salting_partitions (optional, int): Whether to add salting - to the blocking rule. More information on salting can - be found within the docs. Salting is only valid for Spark. - - Examples: - === ":simple-duckdb: DuckDB" - Simple exact rule composition with an `OR` clause - ``` python - import splink.duckdb.blocking_rule_library as brl - brl.or_(brl.exact_match_rule("first_name"), brl.exact_match_rule("surname")) - ``` - Composing a custom rule with an exact match on name and the year - from a date of birth column - ``` python - import splink.duckdb.blocking_rule_library as brl - brl.or_( - brl.exact_match_rule("first_name"), - "substr(l.dob,1,4) = substr(r.dob,1,4)" - ) - ``` - === ":simple-apachespark: Spark" - Simple exact rule composition with an `OR` clause - ``` python - import splink.spark.blocking_rule_library as brl - brl.or_(brl.exact_match_rule("first_name"), brl.exact_match_rule("surname")) - ``` - Composing a custom rule with an exact match on name and the year - from a date of birth column, with additional salting (spark exclusive) - ``` python - import splink.spark.blocking_rule_library as brl - brl.or_( - brl.exact_match_rule("first_name"), - "substr(l.dob,1,4) = substr(r.dob,1,4)", - salting_partitions=5 - ) - ``` - === ":simple-amazonaws: Athena" - Simple exact rule composition with an `OR` clause - ``` python - import splink.athena.blocking_rule_library as brl - brl.or_(brl.exact_match_rule("first_name"), brl.exact_match_rule("surname")) - ``` - Composing a custom rule with an exact match on name and the year - from a date of birth column - ``` python - import splink.athena.blocking_rule_library as brl - brl.or_( - brl.exact_match_rule("first_name"), - "substr(l.dob,1,4) = substr(r.dob,1,4)", - ) - ``` - === ":simple-sqlite: SQLite" - Simple exact rule composition with an `OR` clause - ``` python - import splink.sqlite.blocking_rule_library as brl - brl.or_(brl.exact_match_rule("first_name"), brl.exact_match_rule("surname")) - ``` - Composing a custom rule with an exact match on name and the year - from a date of birth column - ``` python - import splink.sqlite.blocking_rule_library as brl - brl.or_( - brl.exact_match_rule("first_name"), - "substr(l.dob,1,4) = substr(r.dob,1,4)", - ) - ``` - === "PostgreSQL" - Simple exact rule composition with an `OR` clause - ``` python - import splink.postgres.blocking_rule_library as brl - brl.or_(brl.exact_match_rule("first_name"), brl.exact_match_rule("surname")) - ``` - Composing a custom rule with an exact match on name and the year - from a date of birth column - ``` python - import splink.postgres.blocking_rule_library as brl - brl.or_( - brl.exact_match_rule("first_name"), - "substr(l.dob,1,4) = substr(r.dob,1,4)", - ) - ``` - - Returns: - BlockingRule: A new BlockingRule with the merged - SQL condition - """ - return _br_merge( - *brls, - clause="OR", - salting_partitions=salting_partitions, - ) - - -def not_(*brls: BlockingRule | dict | str, salting_partitions: int = 1) -> BlockingRule: - """Invert a BlockingRule using "NOT". - - Returns a BlockingRule with the same SQL condition as the input, - but prefixed with "NOT". - - Args: - *brls (BlockingRule | dict | str): BlockingRules or - blocking rules in the string/dictionary format. - salting_partitions (optional, int): Whether to add salting - to the blocking rule. More information on salting can - be found within the docs. Salting is only valid for Spark. - - Examples: - === ":simple-duckdb: DuckDB" - Block where we do *not* have an exact match on first name - ``` python - import splink.duckdb.blocking_rule_library as brl - brl.not_(brl.exact_match_rule("first_name")) - ``` - === ":simple-apachespark: Spark" - Block where we do *not* have an exact match on first name - ``` python - import splink.spark.blocking_rule_library as brl - brl.not_(brl.exact_match_rule("first_name")) - ``` - === ":simple-amazonaws: Athena" - Block where we do *not* have an exact match on first name - ``` python - import splink.athena.blocking_rule_library as brl - brl.not_(brl.exact_match_rule("first_name")) - ``` - === ":simple-sqlite: SQLite" - Block where we do *not* have an exact match on first name - ``` python - import splink.sqlite.blocking_rule_library as brl - brl.not_(brl.exact_match_rule("first_name")) - ``` - === "PostgreSQL" - Block where we do *not* have an exact match on first name - ``` python - import splink.postgres.blocking_rule_library as brl - brl.not_(brl.exact_match_rule("first_name")) - ``` - - Returns: - BlockingRule: A new BlockingRule with the merged - SQL condition - """ - if len(brls) == 0: - raise TypeError("You must provide at least one BlockingRule") - elif len(brls) > 1: - warnings.warning( - "More than one BlockingRule entered for `NOT` composition. " - "This function only accepts one argument and will only use your " - "first BlockingRule.", - SyntaxWarning, - stacklevel=2, - ) - - brls, sql_dialect, salt = _parse_blocking_rules(*brls) - br = brls[0] - blocking_rule = f"NOT ({br.blocking_rule_sql})" - - br_dict = { - "blocking_rule": blocking_rule, - "sql_dialect": sql_dialect, - } - - if salting_partitions > 1: - salt = salting_partitions - if salt > 1: - br_dict["salting_partitions"] = salt - - return blocking_rule_to_obj(br_dict) - - -def _br_merge( - *brls: BlockingRule | dict | str, - clause: str, - salting_partitions: int = None, -) -> BlockingRule: - if len(brls) == 0: - raise ValueError("You must provide at least one BlockingRule") - - brs, sql_dialect, salt = _parse_blocking_rules(*brls) - if len(brs) > 1: - conditions = (f"({br.blocking_rule_sql})" for br in brs) - else: - conditions = (br.blocking_rule_sql for br in brs) - - blocking_rule = f" {clause} ".join(conditions) - - br_dict = { - "blocking_rule": blocking_rule, - "sql_dialect": sql_dialect, - } - - if salting_partitions > 1: - salt = salting_partitions - if salt > 1: - br_dict["salting_partitions"] = salt - - return blocking_rule_to_obj(br_dict) - - -def _parse_blocking_rules( - *brs: BlockingRule | dict | str, -) -> tuple[list[BlockingRule], str | None]: - brs = [_to_blocking_rule(br) for br in brs] - sql_dialect = _unify_sql_dialects(brs) - salting_partitions = max([getattr(br, "salting_partitions", 1) for br in brs]) - return brs, sql_dialect, salting_partitions - - -def _to_blocking_rule(br): - return blocking_rule_to_obj(br) diff --git a/splink/blocking_rule_creator.py b/splink/blocking_rule_creator.py new file mode 100644 index 0000000000..37c3ab85f6 --- /dev/null +++ b/splink/blocking_rule_creator.py @@ -0,0 +1,47 @@ +from abc import ABC, abstractmethod +from typing import final + +from .blocking import BlockingRule, blocking_rule_to_obj +from .dialects import SplinkDialect + + +class BlockingRuleCreator(ABC): + def __init__(self, salting_partitions=None, arrays_to_explode=None): + self._salting_partitions = salting_partitions + self._arrays_to_explode = arrays_to_explode + + # @property because merged levels need logic to determine salting partitions + @property + def salting_partitions(self): + return self._salting_partitions + + @property + def arrays_to_explode(self): + return self._arrays_to_explode + + @abstractmethod + def create_sql(self, sql_dialect: SplinkDialect) -> str: + pass + + @final + def create_blocking_rule_dict(self, sql_dialect_str: str) -> dict: + sql_dialect = SplinkDialect.from_string(sql_dialect_str) + level_dict = { + "blocking_rule": self.create_sql(sql_dialect), + "sql_dialect": sql_dialect_str, + } + + if self.salting_partitions and self.arrays_to_explode: + raise ValueError("Cannot use both salting_partitions and arrays_to_explode") + + if self.salting_partitions: + level_dict["salting_partitions"] = self.salting_partitions + + if self.arrays_to_explode: + level_dict["arrays_to_explode"] = self.arrays_to_explode + + return level_dict + + @final + def get_blocking_rule(self, sql_dialect_str: str) -> BlockingRule: + return blocking_rule_to_obj(self.create_blocking_rule_dict(sql_dialect_str)) diff --git a/splink/blocking_rule_library.py b/splink/blocking_rule_library.py new file mode 100644 index 0000000000..38ed33d27d --- /dev/null +++ b/splink/blocking_rule_library.py @@ -0,0 +1,176 @@ +from typing import Union, final + +from sqlglot import TokenError, parse_one + +from .blocking_rule_creator import BlockingRuleCreator +from .column_expression import ColumnExpression +from .dialects import SplinkDialect + + +def _translate_sql_string( + sqlglot_base_dialect_sql: str, + to_sqlglot_dialect: str, + from_sqlglot_dialect: str = None, +) -> str: + tree = parse_one(sqlglot_base_dialect_sql, read=from_sqlglot_dialect) + + return tree.sql(dialect=to_sqlglot_dialect) + + +class ExactMatchRule(BlockingRuleCreator): + def __init__( + self, + col_name_or_expr: Union[str, ColumnExpression], + salting_partitions=None, + arrays_to_explode=None, + ): + super().__init__( + salting_partitions=salting_partitions, arrays_to_explode=arrays_to_explode + ) + self.col_expression = ColumnExpression.instantiate_if_str(col_name_or_expr) + + def create_sql(self, sql_dialect: SplinkDialect) -> str: + self.col_expression.sql_dialect = sql_dialect + col = self.col_expression + return f"{col.l_name} = {col.r_name}" + + +class CustomRule(BlockingRuleCreator): + def __init__( + self, + sql_condition: str, + base_dialect_str: str = None, + salting_partitions=None, + arrays_to_explode=None, + ): + super().__init__( + salting_partitions=salting_partitions, arrays_to_explode=arrays_to_explode + ) + self.sql_condition = sql_condition + + self.base_dialect_str = base_dialect_str + + def create_sql(self, sql_dialect: SplinkDialect) -> str: + sql_condition = self.sql_condition + if self.base_dialect_str is not None: + base_dialect = SplinkDialect.from_string(self.base_dialect_str) + # if we are told it is one dialect, but try to create comparison level + # of another, try to translate with sqlglot + if sql_dialect != base_dialect: + base_dialect_sqlglot_name = base_dialect.sqlglot_name + + # as default, translate condition into our dialect + try: + sql_condition = _translate_sql_string( + sql_condition, + sql_dialect.sqlglot_name, + base_dialect_sqlglot_name, + ) + # if we hit a sqlglot error, assume users knows what they are doing, + # e.g. it is something custom / unknown to sqlglot + # error will just appear when they try to use it + except TokenError: + pass + return sql_condition + + +class _Merge(BlockingRuleCreator): + @final + def __init__( + self, + *blocking_rules: Union[BlockingRuleCreator, dict], + salting_partitions=None, + arrays_to_explode=None, + ): + super().__init__( + salting_partitions=salting_partitions, arrays_to_explode=arrays_to_explode + ) + num_levels = len(blocking_rules) + if num_levels == 0: + raise ValueError( + f"Must provide at least one blocking rule to {type(self)}()" + ) + self.blocking_rules = blocking_rules + + @property + def salting_partitions(self): + if ( + hasattr(self, "_salting_partitions") + and self._salting_partitions is not None + ): + return self._salting_partitions + + return max( + [ + br.salting_partitions + for br in self.blocking_rules + if br.salting_partitions is not None + ], + default=None, + ) + + @property + def arrays_to_explode(self): + if hasattr(self, "_arrays_to_explode"): + return self._arrays_to_explode + + if any([br.arrays_to_explode for br in self.blocking_rules]): + raise ValueError("Cannot merge blocking rules with arrays_to_explode") + return None + + @final + def create_sql(self, sql_dialect: SplinkDialect) -> str: + return f" {self._clause} ".join( + map(lambda cl: f"({cl.create_sql(sql_dialect)})", self.blocking_rules) + ) + + +class And(_Merge): + _clause = "AND" + + +class Or(_Merge): + _clause = "OR" + + +class Not(BlockingRuleCreator): + def __init__(self, blocking_rule_creator): + self.blocking_rule_creator = blocking_rule_creator + + @property + def salting_partitions(self): + return self.blocking_rule_creator.salting_partitions + + @property + def arrays_to_explode(self): + if self.blocking_rule_creator.arrays_to_explode: + raise ValueError("Cannot use arrays_to_explode with Not") + return None + + @final + def create_sql(self, sql_dialect: SplinkDialect) -> str: + return f"NOT ({self.blocking_rule_creator.create_sql(sql_dialect)})" + + +def block_on( + *col_names_or_exprs: Union[str, ColumnExpression], + salting_partitions=None, + arrays_to_explode=None, +) -> BlockingRuleCreator: + if isinstance(col_names_or_exprs[0], list): + raise TypeError( + "block_on no longer accepts a list as the first argument. " + "Please pass individual column names or expressions as separate arguments" + ' e.g. block_on("first_name", "dob") not block_on(["first_name", "dob"])' + ) + + if len(col_names_or_exprs) == 1: + br = ExactMatchRule(col_names_or_exprs[0]) + else: + br = And(*[ExactMatchRule(c) for c in col_names_or_exprs]) + + if salting_partitions: + br._salting_partitions = salting_partitions + if arrays_to_explode: + br._arrays_to_explode = arrays_to_explode + return br diff --git a/splink/blocking_rules_library.py b/splink/blocking_rules_library.py deleted file mode 100644 index f23ab9c615..0000000000 --- a/splink/blocking_rules_library.py +++ /dev/null @@ -1,130 +0,0 @@ -from __future__ import annotations - -import warnings - -import sqlglot - -from .blocking import BlockingRule, blocking_rule_to_obj -from .blocking_rule_composition import and_ -from .misc import ensure_is_list -from .sql_transform import add_quotes_and_table_prefix - - -def exact_match_rule( - col_name: str, - _sql_dialect: str, - salting_partitions: int = None, -) -> BlockingRule: - """Represents an exact match blocking rule. - - **DEPRECATED:** - `exact_match_rule` is deprecated. Please use `block_on` - instead, which acts as a wrapper with additional functionality. - - Args: - col_name (str): Input column name, or a str represent a sql - statement you'd like to match on. For example, `surname` or - `"substr(surname,1,2)"` are both valid. - salting_partitions (optional, int): Whether to add salting - to the blocking rule. More information on salting can - be found within the docs. Salting is currently only valid - for Spark. - """ - warnings.warn( - "`exact_match_rule` is deprecated; use `block_on`", - DeprecationWarning, - stacklevel=2, - ) - - syntax_tree = sqlglot.parse_one(col_name, read=_sql_dialect) - - l_col = add_quotes_and_table_prefix(syntax_tree, "l").sql(_sql_dialect) - r_col = add_quotes_and_table_prefix(syntax_tree, "r").sql(_sql_dialect) - - blocking_rule = f"{l_col} = {r_col}" - - return blocking_rule_to_obj( - { - "blocking_rule": blocking_rule, - "salting_partitions": salting_partitions, - "sql_dialect": _sql_dialect, - } - ) - - -def block_on( - _exact_match, - col_names: list[str], - salting_partitions: int = 1, -) -> BlockingRule: - """The `block_on` function generates blocking rules that facilitate - efficient equi-joins based on the columns or SQL statements - specified in the col_names argument. When multiple columns or - SQL snippets are provided, the function generates a compound - blocking rule, connecting individual match conditions with - "AND" clauses. - - This function is designed for scenarios where you aim to achieve - efficient yet straightforward blocking conditions based on one - or more columns or SQL snippets. - - For more information on the intended use cases of `block_on`, please see - [the following discussion](https://github.com/moj-analytical-services/splink/issues/1376). - - Further information on equi-join conditions can be found - [here](https://moj-analytical-services.github.io/splink/topic_guides/blocking/performance.html) - - This function acts as a shorthand alias for the `brl.and_` syntax: - ```py - import splink.duckdb.blocking_rule_library as brl - brl.and_(brl.exact_match_rule, brl.exact_match_rule, ...) - ``` - - Args: - col_names (list[str]): A list of input columns or sql conditions - you wish to create blocks on. - salting_partitions (optional, int): Whether to add salting - to the blocking rule. More information on salting can - be found within the docs. Salting is only valid for Spark. - - Examples: - === ":simple-duckdb: DuckDB" - ``` python - from splink.duckdb.blocking_rule_library import block_on - block_on("first_name") # check for exact matches on first name - sql = "substr(surname,1,2)" - block_on([sql, "surname"]) - ``` - === ":simple-apachespark: Spark" - ``` python - from splink.spark.blocking_rule_library import block_on - block_on("first_name") # check for exact matches on first name - sql = "substr(surname,1,2)" - block_on([sql, "surname"], salting_partitions=1) - ``` - === ":simple-amazonaws: Athena" - ``` python - from splink.athena.blocking_rule_library import block_on - block_on("first_name") # check for exact matches on first name - sql = "substr(surname,1,2)" - block_on([sql, "surname"]) - ``` - === ":simple-sqlite: SQLite" - ``` python - from splink.sqlite.blocking_rule_library import block_on - block_on("first_name") # check for exact matches on first name - sql = "substr(surname,1,2)" - block_on([sql, "surname"]) - ``` - === "PostgreSQL" - ``` python - from splink.postgres.blocking_rule_library import block_on - block_on("first_name") # check for exact matches on first name - sql = "substr(surname,1,2)" - block_on([sql, "surname"]) - ``` - """ # noqa: E501 - - col_names = ensure_is_list(col_names) - em_rules = [_exact_match(col) for col in col_names] - return and_(*em_rules, salting_partitions=salting_partitions) diff --git a/splink/column_expression.py b/splink/column_expression.py index 06da4059b7..c0a521a37e 100644 --- a/splink/column_expression.py +++ b/splink/column_expression.py @@ -8,7 +8,10 @@ from .dialects import SplinkDialect from .input_column import SqlglotColumnTreeBuilder -from .sql_transform import add_suffix_to_all_column_identifiers +from .sql_transform import ( + add_suffix_to_all_column_identifiers, + add_table_to_all_column_identifiers, +) class ColumnExpression: @@ -211,6 +214,22 @@ def name_r(self) -> str: ) return self.apply_operations(base_name, self.sql_dialect) + @property + def l_name(self) -> str: + sql_expression = self.parse_input_string(self.sql_dialect) + base_name = add_table_to_all_column_identifiers( + sql_expression, "l", self.sql_dialect.sqlglot_name + ) + return self.apply_operations(base_name, self.sql_dialect) + + @property + def r_name(self) -> str: + sql_expression = self.parse_input_string(self.sql_dialect) + base_name = add_table_to_all_column_identifiers( + sql_expression, "r", self.sql_dialect.sqlglot_name + ) + return self.apply_operations(base_name, self.sql_dialect) + @property def output_column_name(self) -> str: allowed_chars = string.ascii_letters + string.digits + "_" diff --git a/splink/duckdb/blocking_rule_library.py b/splink/duckdb/blocking_rule_library.py deleted file mode 100644 index b7830b9137..0000000000 --- a/splink/duckdb/blocking_rule_library.py +++ /dev/null @@ -1,9 +0,0 @@ -from ..blocking_rule_composition import ( # noqa: F401 - and_, - not_, - or_, -) -from .duckdb_helpers.duckdb_blocking_rule_imports import ( # noqa: F401 - block_on, - exact_match_rule, -) diff --git a/splink/duckdb/duckdb_helpers/duckdb_blocking_rule_imports.py b/splink/duckdb/duckdb_helpers/duckdb_blocking_rule_imports.py deleted file mode 100644 index 6afd8c8767..0000000000 --- a/splink/duckdb/duckdb_helpers/duckdb_blocking_rule_imports.py +++ /dev/null @@ -1,27 +0,0 @@ -from __future__ import annotations - -from functools import partial - -from ...blocking_rules_library import ( - BlockingRule, - exact_match_rule, -) -from ...blocking_rules_library import ( - block_on as _block_on_, -) - -exact_match_rule = partial(exact_match_rule, _sql_dialect="duckdb") - - -def block_on( - col_names: list[str], - salting_partitions: int = 1, -) -> BlockingRule: - return _block_on_( - exact_match_rule, - col_names, - salting_partitions, - ) - - -block_on.__doc__ = _block_on_.__doc__ diff --git a/splink/linker.py b/splink/linker.py index 8f8eeac8fb..80b6be0e8d 100644 --- a/splink/linker.py +++ b/splink/linker.py @@ -9,7 +9,7 @@ from copy import copy, deepcopy from pathlib import Path from statistics import median -from typing import Dict +from typing import Dict, Union from splink.input_column import InputColumn @@ -39,6 +39,7 @@ blocking_rule_to_obj, materialise_exploded_id_tables, ) +from .blocking_rule_creator import BlockingRuleCreator from .charts import ( accuracy_chart, completeness_chart, @@ -239,6 +240,7 @@ def __init__( # for now we instantiate all the correct types before the validator sees it settings_dict = deepcopy(settings_dict) self._instantiate_comparison_levels(settings_dict) + self._instantiate_blocking_rules(settings_dict) self._validate_settings_components(settings_dict) self._setup_settings_objs(settings_dict) @@ -453,7 +455,6 @@ def _random_sample_sql( def _register_input_tables( self, input_tables, input_aliases ) -> Dict[str, SplinkDataFrame]: - if input_aliases is None: input_table_aliases = [ f"__splink__input_table_{i}" for i, _ in enumerate(input_tables) @@ -552,6 +553,21 @@ def _instantiate_comparison_levels(self, settings_dict): elif isinstance(comparison, ComparisonCreator): comparisons[idx_c] = comparison.get_comparison(dialect) + def _instantiate_blocking_rules(self, settings_dict): + """ + Mutate our settings_dict, so that any BlockingRuleCreator + instances are instead replaced with BlockingRules + """ + dialect = self._sql_dialect + if settings_dict is None: + return + if "blocking_rules_to_generate_predictions" not in settings_dict: + return + brs = settings_dict["blocking_rules_to_generate_predictions"] + for idx_c, br in enumerate(brs): + if isinstance(br, BlockingRuleCreator): + brs[idx_c] = br.create_blocking_rule_dict(dialect) + def _initialise_df_concat(self, materialise=False): cache = self._intermediate_table_cache concat_df = None @@ -1495,7 +1511,7 @@ def estimate_m_from_label_column(self, label_colname: str): def estimate_parameters_using_expectation_maximisation( self, - blocking_rule: str, + blocking_rule: Union[str, BlockingRuleCreator], comparisons_to_deactivate: list[str | Comparison] = None, comparison_levels_to_reverse_blocking_rule: list[ComparisonLevel] = None, estimate_without_term_frequencies: bool = False, @@ -1553,8 +1569,8 @@ def estimate_parameters_using_expectation_maximisation( ``` Args: - blocking_rule (BlockingRule | str): The blocking rule used to generate - pairwise record comparisons. + blocking_rule (BlockingRuleCreator | str): The blocking rule used to + generate pairwise record comparisons. comparisons_to_deactivate (list, optional): By default, splink will analyse the blocking rule provided and estimate the m parameters for all comaprisons except those included in the blocking rule. If @@ -1604,10 +1620,8 @@ def estimate_parameters_using_expectation_maximisation( # Ensure this has been run on the main linker so that it's in the cache # to be used by the training linkers self._initialise_df_concat_with_tf() - - # Extract the blocking rule - # Check it's a BlockingRule (not a SaltedBlockingRule, ExlpodingBlockingRule) - # and raise error if not specfically a BlockingRule + if isinstance(blocking_rule, BlockingRuleCreator): + blocking_rule = blocking_rule.create_blocking_rule_dict(self._sql_dialect) blocking_rule = blocking_rule_to_obj(blocking_rule) if type(blocking_rule) is not BlockingRule: raise TypeError( diff --git a/splink/postgres/blocking_rule_library.py b/splink/postgres/blocking_rule_library.py deleted file mode 100644 index 4eb24e2610..0000000000 --- a/splink/postgres/blocking_rule_library.py +++ /dev/null @@ -1,9 +0,0 @@ -from ..blocking_rule_composition import ( # noqa: F401 - and_, - not_, - or_, -) -from .postgres_helpers.postgres_blocking_rule_imports import ( # noqa: F401 - block_on, - exact_match_rule, -) diff --git a/splink/postgres/postgres_helpers/postgres_blocking_rule_imports.py b/splink/postgres/postgres_helpers/postgres_blocking_rule_imports.py deleted file mode 100644 index e99bc20199..0000000000 --- a/splink/postgres/postgres_helpers/postgres_blocking_rule_imports.py +++ /dev/null @@ -1,27 +0,0 @@ -from __future__ import annotations - -from functools import partial - -from ...blocking_rules_library import ( - BlockingRule, - exact_match_rule, -) -from ...blocking_rules_library import ( - block_on as _block_on_, -) - -exact_match_rule = partial(exact_match_rule, _sql_dialect="postgres") - - -def block_on( - col_names: list[str], - salting_partitions: int = 1, -) -> BlockingRule: - return _block_on_( - exact_match_rule, - col_names, - salting_partitions, - ) - - -block_on.__doc__ = _block_on_.__doc__ diff --git a/splink/spark/blocking_rule_library.py b/splink/spark/blocking_rule_library.py deleted file mode 100644 index a1b339166d..0000000000 --- a/splink/spark/blocking_rule_library.py +++ /dev/null @@ -1,9 +0,0 @@ -from ..blocking_rule_composition import ( # noqa: F401 - and_, - not_, - or_, -) -from .spark_helpers.spark_blocking_rule_imports import ( # noqa: F401 - block_on, - exact_match_rule, -) diff --git a/splink/spark/spark_helpers/spark_blocking_rule_imports.py b/splink/spark/spark_helpers/spark_blocking_rule_imports.py deleted file mode 100644 index 04fbab7b95..0000000000 --- a/splink/spark/spark_helpers/spark_blocking_rule_imports.py +++ /dev/null @@ -1,27 +0,0 @@ -from __future__ import annotations - -from functools import partial - -from ...blocking_rules_library import ( - BlockingRule, - exact_match_rule, -) -from ...blocking_rules_library import ( - block_on as _block_on_, -) - -exact_match_rule = partial(exact_match_rule, _sql_dialect="spark") - - -def block_on( - col_names: list[str], - salting_partitions: int = 1, -) -> BlockingRule: - return _block_on_( - exact_match_rule, - col_names, - salting_partitions, - ) - - -block_on.__doc__ = _block_on_.__doc__ diff --git a/splink/sql_transform.py b/splink/sql_transform.py index c4f31b3ff9..14bd6ef638 100644 --- a/splink/sql_transform.py +++ b/splink/sql_transform.py @@ -132,3 +132,13 @@ def add_suffix_to_all_column_identifiers( identifier.args["this"] = identifier.args["this"] + suffix return tree.sql(dialect=sqlglot_dialect) + + +# TODO: can we get rid of add_quotes_and_table_prefix and use this everywhere instead +def add_table_to_all_column_identifiers( + sql_str: str, table_name: str, sqlglot_dialect: str +): + tree = sqlglot.parse_one(sql_str, dialect=sqlglot_dialect) + for col in tree.find_all(exp.Column): + col.args["table"] = table_name + return tree.sql(dialect=sqlglot_dialect) diff --git a/splink/sqlite/blocking_rule_library.py b/splink/sqlite/blocking_rule_library.py deleted file mode 100644 index d390a2d9cd..0000000000 --- a/splink/sqlite/blocking_rule_library.py +++ /dev/null @@ -1,9 +0,0 @@ -from ..blocking_rule_composition import ( # noqa: F401 - and_, - not_, - or_, -) -from .sqlite_helpers.sqlite_blocking_rule_imports import ( # noqa: F401 - block_on, - exact_match_rule, -) diff --git a/splink/sqlite/sqlite_helpers/sqlite_blocking_rule_imports.py b/splink/sqlite/sqlite_helpers/sqlite_blocking_rule_imports.py deleted file mode 100644 index b5229b29f4..0000000000 --- a/splink/sqlite/sqlite_helpers/sqlite_blocking_rule_imports.py +++ /dev/null @@ -1,27 +0,0 @@ -from __future__ import annotations - -from functools import partial - -from ...blocking_rules_library import ( - BlockingRule, - exact_match_rule, -) -from ...blocking_rules_library import ( - block_on as _block_on_, -) - -exact_match_rule = partial(exact_match_rule, _sql_dialect="sqlite") - - -def block_on( - col_names: list[str], - salting_partitions: int = 1, -) -> BlockingRule: - return _block_on_( - exact_match_rule, - col_names, - salting_partitions, - ) - - -block_on.__doc__ = _block_on_.__doc__ diff --git a/tests/helpers.py b/tests/helpers.py index 3dfde4a60c..57c9c62ac9 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -6,10 +6,6 @@ from sqlalchemy.dialects import postgresql from sqlalchemy.types import INTEGER, TEXT -import splink.duckdb.blocking_rule_library as brl_duckdb -import splink.postgres.blocking_rule_library as brl_postgres -import splink.spark.blocking_rule_library as brl_spark -import splink.sqlite.blocking_rule_library as brl_sqlite from splink.database_api import DuckDBAPI, PostgresAPI, SparkAPI, SQLiteAPI from splink.linker import Linker @@ -45,11 +41,6 @@ def load_frame_from_csv(self, path): def load_frame_from_parquet(self, path): return pd.read_parquet(path) - @property - @abstractmethod - def brl(self): - pass - class DuckDBTestHelper(TestHelper): @property @@ -63,10 +54,6 @@ def convert_frame(self, df): def date_format(self): return "%Y-%m-%d" - @property - def brl(self): - return brl_duckdb - class SparkTestHelper(TestHelper): def __init__(self, spark_creator_function): @@ -94,10 +81,6 @@ def load_frame_from_parquet(self, path): df.persist() return df - @property - def brl(self): - return brl_spark - class SQLiteTestHelper(TestHelper): _frame_counter = 0 @@ -130,10 +113,6 @@ def load_frame_from_csv(self, path): def load_frame_from_parquet(self, path): return self.convert_frame(super().load_frame_from_parquet(path)) - @property - def brl(self): - return brl_sqlite - class PostgresTestHelper(TestHelper): _frame_counter = 0 @@ -179,10 +158,6 @@ def load_frame_from_csv(self, path): def load_frame_from_parquet(self, path): return self.convert_frame(super().load_frame_from_parquet(path)) - @property - def brl(self): - return brl_postgres - class SplinkTestException(Exception): pass diff --git a/tests/test_accuracy.py b/tests/test_accuracy.py index a640cb4d28..b0b6586a52 100644 --- a/tests/test_accuracy.py +++ b/tests/test_accuracy.py @@ -5,9 +5,9 @@ predictions_from_sample_of_pairwise_labels_sql, truth_space_table_from_labels_with_predictions_sqls, ) +from splink.blocking_rule_library import block_on from splink.comparison_library import ExactMatch from splink.database_api import DuckDBAPI -from splink.duckdb.blocking_rule_library import block_on from splink.linker import Linker from .basic_settings import get_settings_dict @@ -40,7 +40,7 @@ def test_scored_labels_table(): ], "blocking_rules_to_generate_predictions": [ "l.surname = r.surname", - block_on("dob"), + block_on("dob").get_blocking_rule("duckdb"), ], } @@ -101,7 +101,7 @@ def test_truth_space_table(): ], "blocking_rules_to_generate_predictions": [ "l.surname = r.surname", - block_on("dob"), + block_on("dob").get_blocking_rule("duckdb"), ], } diff --git a/tests/test_analyse_blocking.py b/tests/test_analyse_blocking.py index a796a31665..60e8b4f079 100644 --- a/tests/test_analyse_blocking.py +++ b/tests/test_analyse_blocking.py @@ -3,6 +3,7 @@ from splink.analyse_blocking import cumulative_comparisons_generated_by_blocking_rules from splink.blocking import BlockingRule +from splink.blocking_rule_library import CustomRule, Or, block_on from splink.database_api import DuckDBAPI from splink.linker import Linker @@ -14,7 +15,6 @@ def test_analyse_blocking_slow_methodology(test_helpers, dialect): helper = test_helpers[dialect] Linker = helper.Linker - brl = helper.brl df_1 = pd.DataFrame( [ @@ -79,7 +79,7 @@ def test_analyse_blocking_slow_methodology(test_helpers, dialect): assert res == 1 - rule = brl.block_on(["first_name", "surname"]) + rule = block_on("first_name", "surname").get_blocking_rule(dialect) res = linker.count_num_comparisons_from_blocking_rule( rule, ) @@ -103,7 +103,6 @@ def test_blocking_records_accuracy(test_helpers, dialect): helper = test_helpers[dialect] Linker = helper.Linker - brl = helper.brl # resolve an issue w/ pyspark nulls @@ -157,8 +156,8 @@ def test_blocking_records_accuracy(test_helpers, dialect): ) blocking_rules = [ - brl.block_on("first_name"), - brl.block_on(["first_name", "surname"]), + block_on("first_name").get_blocking_rule(dialect), + block_on("first_name", "surname").get_blocking_rule(dialect), "l.dob = r.dob", ] @@ -189,9 +188,10 @@ def test_blocking_records_accuracy(test_helpers, dialect): blocking_rules = [ "l.surname = r.surname", # 2l:2r, - brl.or_( - brl.block_on("first_name"), - "substr(l.dob,1,4) = substr(r.dob,1,4)", + Or( + block_on("first_name"), CustomRule("substr(l.dob,1,4) = substr(r.dob,1,4)") + ).get_blocking_rule( + dialect ), # 1r:1r, 1l:2l, 1l:2r "l.surname = r.surname", ] @@ -210,9 +210,11 @@ def test_blocking_records_accuracy(test_helpers, dialect): blocking_rules = [ "l.surname = r.surname", # 2l:2r, - brl.or_( - brl.exact_match_rule("first_name"), - "substr(l.dob,1,4) = substr(r.dob,1,4)", + Or( + block_on("first_name"), + CustomRule("substr(l.dob,1,4) = substr(r.dob,1,4)"), + ).get_blocking_rule( + dialect ), # 1l:1r, 1l:2r "l.surname = r.surname", ] @@ -441,7 +443,7 @@ def test_blocking_rule_accepts_different_dialects(): def test_cumulative_br_funs(test_helpers, dialect): helper = test_helpers[dialect] Linker = helper.Linker - brl = helper.brl + df = helper.load_frame_from_csv("./tests/datasets/fake_1000_from_splink_demos.csv") linker = Linker(df, get_settings_dict(), **helper.extra_linker_args()) @@ -449,17 +451,20 @@ def test_cumulative_br_funs(test_helpers, dialect): linker.cumulative_comparisons_from_blocking_rules_records( [ "l.first_name = r.first_name", - brl.block_on("surname"), + block_on("surname").get_blocking_rule(dialect), ] ) linker.cumulative_num_comparisons_from_blocking_rules_chart( [ "l.first_name = r.first_name", - brl.block_on("surname"), + block_on("surname").get_blocking_rule(dialect), ] ) assert ( - linker.count_num_comparisons_from_blocking_rule(brl.block_on("surname")) == 3167 + linker.count_num_comparisons_from_blocking_rule( + block_on("surname").get_blocking_rule(dialect) + ) + == 3167 ) diff --git a/tests/test_blocking.py b/tests/test_blocking.py index 8955b0813a..1354e319af 100644 --- a/tests/test_blocking.py +++ b/tests/test_blocking.py @@ -1,5 +1,7 @@ from splink.blocking import BlockingRule, blocking_rule_to_obj +from splink.blocking_rule_library import block_on from splink.input_column import _get_dialect_quotes +from splink.linker import Linker from splink.settings import Settings from .basic_settings import get_settings_dict @@ -7,12 +9,10 @@ @mark_with_dialects_excluding() -def test_binary_composition_internals_OR(test_helpers, dialect): - helper = test_helpers[dialect] - brl = helper.brl - +def test_preceding_blocking_rules(dialect): settings = get_settings_dict() - br_surname = brl.block_on("surname", salting_partitions=4) + br_surname = block_on("surname", salting_partitions=4).get_blocking_rule(dialect) + q, _ = _get_dialect_quotes(dialect) em_rule = f"l.{q}surname{q} = r.{q}surname{q}" @@ -21,8 +21,8 @@ def test_binary_composition_internals_OR(test_helpers, dialect): assert br_surname.preceding_rules == [] preceding_rules = [ - brl.block_on("first_name"), - brl.block_on(["dob"]), + block_on("first_name").get_blocking_rule(dialect), + block_on("dob").get_blocking_rule(dialect), ] br_surname.add_preceding_rules(preceding_rules) assert br_surname.preceding_rules == preceding_rules @@ -34,7 +34,7 @@ def test_binary_composition_internals_OR(test_helpers, dialect): BlockingRule("l.help = r.help"), "l.help2 = r.help2", {"blocking_rule": "l.help3 = r.help3", "salting_partitions": 3}, - brl.block_on("help4"), + block_on("help4").get_blocking_rule(dialect), ] brs_as_objs = settings_tester._brs_as_objs(brs_as_strings) brs_as_txt = [blocking_rule_to_obj(br).blocking_rule_sql for br in brs_as_strings] @@ -54,23 +54,22 @@ def assess_preceding_rules(settings_brs_index): @mark_with_dialects_excluding() def test_simple_end_to_end(test_helpers, dialect): helper = test_helpers[dialect] - Linker = helper.Linker - brl = helper.brl + df = helper.load_frame_from_csv("./tests/datasets/fake_1000_from_splink_demos.csv") settings = get_settings_dict() settings["blocking_rules_to_generate_predictions"] = [ - brl.block_on(["first_name", "surname"]), - brl.block_on("dob"), + block_on("first_name", "surname"), + block_on("dob"), ] linker = Linker(df, settings, **helper.extra_linker_args()) - linker.estimate_u_using_random_sampling(target_rows=1e5) + linker.estimate_u_using_random_sampling(target_rows=1e3) - blocking_rule = brl.block_on(["first_name", "surname"]) + blocking_rule = block_on("first_name", "surname") linker.estimate_parameters_using_expectation_maximisation(blocking_rule) - linker.estimate_parameters_using_expectation_maximisation(brl.block_on("dob")) + linker.estimate_parameters_using_expectation_maximisation(block_on("dob")) linker.predict() diff --git a/tests/test_blocking_rule_composition.py b/tests/test_blocking_rule_composition.py index 37bdd0e475..deaa7b40cc 100644 --- a/tests/test_blocking_rule_composition.py +++ b/tests/test_blocking_rule_composition.py @@ -1,45 +1,48 @@ import pytest +from splink.blocking_rule_library import And, CustomRule, Not, Or, block_on from splink.input_column import _get_dialect_quotes from .decorator import mark_with_dialects_excluding -def binary_composition_internals(clause, comp_fun, brl, dialect): +def binary_composition_internals(clause, comp_fun, dialect): q, _ = _get_dialect_quotes(dialect) # Test what happens when only one value is fed # It should just report the regular outputs of our comparison level func - level = comp_fun(brl.block_on("tom")) - assert level.blocking_rule_sql == f"l.{q}tom{q} = r.{q}tom{q}" + level = comp_fun(block_on("tom")).get_blocking_rule(dialect) + assert level.blocking_rule_sql == f"(l.{q}tom{q} = r.{q}tom{q})" # Exact match and null level composition level = comp_fun( - brl.block_on("first_name"), - brl.block_on("surname"), - ) + block_on("first_name"), + block_on("surname"), + ).get_blocking_rule(dialect) exact_match_sql = f"(l.{q}first_name{q} = r.{q}first_name{q}) {clause} (l.{q}surname{q} = r.{q}surname{q})" # noqa: E501 assert level.blocking_rule_sql == exact_match_sql - # brl.not_(or_(...)) composition - level = brl.not_( - comp_fun(brl.block_on("first_name"), brl.block_on("surname")), - ) + # not_(or_(...)) composition + level = Not( + comp_fun(block_on("first_name"), block_on("surname")), + ).get_blocking_rule(dialect) assert level.blocking_rule_sql == f"NOT ({exact_match_sql})" # Check salting outputs # salting included in the composition function - salt = comp_fun( - "l.help2 = r.help2", - {"blocking_rule": "l.help3 = r.help3", "salting_partitions": 10}, - brl.block_on("help4"), - salting_partitions=4, + salt = ( + comp_fun( + CustomRule("l.help2 = r.help2"), + CustomRule("l.help3 = r.help3", salting_partitions=10), + block_on("help4"), + salting_partitions=4, + ).get_blocking_rule(dialect) ).salting_partitions # salting included in one of the levels salt_2 = comp_fun( - "l.help2 = r.help2", - {"blocking_rule": "l.help3 = r.help3", "salting_partitions": 3}, - brl.block_on("help4"), + CustomRule("l.help2 = r.help2"), + CustomRule("l.help3 = r.help3", salting_partitions=3), + block_on("help4"), ).salting_partitions assert salt == 4 @@ -51,22 +54,18 @@ def binary_composition_internals(clause, comp_fun, brl, dialect): @mark_with_dialects_excluding() def test_binary_composition_internals_OR(test_helpers, dialect): - brl = test_helpers[dialect].brl - binary_composition_internals("OR", brl.or_, brl, dialect) + binary_composition_internals("OR", Or, dialect) @mark_with_dialects_excluding() def test_binary_composition_internals_AND(test_helpers, dialect): - brl = test_helpers[dialect].brl - binary_composition_internals("AND", brl.and_, brl, dialect) + binary_composition_internals("AND", And, dialect) def test_not(): - import splink.duckdb.blocking_rule_library as brl - # Integration test for a simple dictionary blocking rule dob_jan_first = {"blocking_rule": "SUBSTR(dob_std_l, -5) = '01-01'"} - brl.not_(dob_jan_first) + Not(dob_jan_first) with pytest.raises(TypeError): - brl.not_() + Not() diff --git a/tests/test_find_new_matches.py b/tests/test_find_new_matches.py index 85aff7a687..0613090ba1 100644 --- a/tests/test_find_new_matches.py +++ b/tests/test_find_new_matches.py @@ -3,6 +3,7 @@ import pandas as pd import splink.comparison_library as cl +from splink.blocking_rule_library import block_on from .basic_settings import get_settings_dict from .decorator import mark_with_dialects_excluding @@ -75,7 +76,7 @@ def test_tf_tables_init_works(test_helpers, dialect): def test_matches_work(test_helpers, dialect): helper = test_helpers[dialect] Linker = helper.Linker - brl = helper.brl + df = helper.load_frame_from_csv("./tests/datasets/fake_1000_from_splink_demos.csv") linker = Linker(df, get_settings_dict(), **helper.extra_linker_args()) @@ -83,7 +84,7 @@ def test_matches_work(test_helpers, dialect): # Train our model to get more reasonable outputs... linker.estimate_u_using_random_sampling(max_pairs=1e6) - blocking_rule = brl.block_on(["first_name", "surname"]) + blocking_rule = block_on("first_name", "surname").get_blocking_rule(dialect) linker.estimate_parameters_using_expectation_maximisation(blocking_rule) blocking_rule = "l.dob = r.dob" diff --git a/tests/test_salting_len.py b/tests/test_salting_len.py index 5651acf59a..6d9e3b357e 100644 --- a/tests/test_salting_len.py +++ b/tests/test_salting_len.py @@ -1,6 +1,6 @@ import pytest -import splink.spark.blocking_rule_library as brl +from splink.blocking_rule_library import block_on from splink.linker import Linker from tests.basic_settings import get_settings_dict @@ -60,7 +60,7 @@ def test_salting_spark(spark, spark_api): ] blocking_rules_salted = [ - brl.block_on("surname", salting_partitions=3), + block_on("surname", salting_partitions=3).get_blocking_rule("spark"), {"blocking_rule": "l.first_name = r.first_name", "salting_partitions": 7}, "l.dob = r.dob", ] diff --git a/tests/test_settings_validation.py b/tests/test_settings_validation.py index aed6fea4a7..0493b63aa3 100644 --- a/tests/test_settings_validation.py +++ b/tests/test_settings_validation.py @@ -3,11 +3,11 @@ import pandas as pd import pytest +from splink.blocking_rule_library import block_on from splink.comparison import Comparison from splink.comparison_library import LevenshteinAtThresholds from splink.convert_v2_to_v3 import convert_settings_from_v2_to_v3 from splink.database_api import DuckDBAPI -from splink.duckdb.blocking_rule_library import block_on from splink.exceptions import ErrorLogger from splink.linker import Linker from splink.settings_validation.log_invalid_columns import ( @@ -84,9 +84,9 @@ 'dmetaphone(c."surname", r."surname")': [ InvalidTableNamesLogGenerator({"c.surname"}) ], - block_on(["left", "right"]).blocking_rule_sql: [ - MissingColumnsLogGenerator({"left", "right"}) - ], + block_on("left", "right") + .get_blocking_rule("duckdb") + .blocking_rule_sql: [MissingColumnsLogGenerator({"left", "right"})], }