Skip to content

Commit

Permalink
Merge pull request #1700 from moj-analytical-services/refactor_blockingrule_rename_sql_property
Browse files Browse the repository at this point in the history

BlockingRule: Clarify name of sql property
  • Loading branch information
RobinL authored Nov 7, 2023
2 parents 3c79882 + 2bcad9b commit 5104945
Show file tree
Hide file tree
Showing 12 changed files with 70 additions and 53 deletions.
8 changes: 4 additions & 4 deletions docs/demos/tutorials/03_Blocking.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -153,19 +153,19 @@
"\n",
"blocking_rule_1 = block_on([\"substr(first_name, 1,1)\", \"surname\"])\n",
"count = linker.count_num_comparisons_from_blocking_rule(blocking_rule_1)\n",
"print(f\"Number of comparisons generated by '{blocking_rule_1.sql}': {count:,.0f}\")\n",
"print(f\"Number of comparisons generated by '{blocking_rule_1.blocking_rule_sql}': {count:,.0f}\")\n",
"\n",
"blocking_rule_2 = block_on(\"surname\")\n",
"count = linker.count_num_comparisons_from_blocking_rule(blocking_rule_2)\n",
"print(f\"Number of comparisons generated by '{blocking_rule_2.sql}': {count:,.0f}\")\n",
"print(f\"Number of comparisons generated by '{blocking_rule_2.blocking_rule_sql}': {count:,.0f}\")\n",
"\n",
"blocking_rule_3 = block_on(\"email\")\n",
"count = linker.count_num_comparisons_from_blocking_rule(blocking_rule_3)\n",
"print(f\"Number of comparisons generated by '{blocking_rule_3.sql}': {count:,.0f}\")\n",
"print(f\"Number of comparisons generated by '{blocking_rule_3.blocking_rule_sql}': {count:,.0f}\")\n",
"\n",
"blocking_rule_4 = block_on([\"city\", \"first_name\"])\n",
"count = linker.count_num_comparisons_from_blocking_rule(blocking_rule_4)\n",
"print(f\"Number of comparisons generated by '{blocking_rule_4.sql}': {count:,.0f}\")\n"
"print(f\"Number of comparisons generated by '{blocking_rule_4.blocking_rule_sql}': {count:,.0f}\")\n"
]
},
{
Expand Down
9 changes: 7 additions & 2 deletions splink/accuracy.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
from copy import deepcopy
from typing import TYPE_CHECKING

from .block_from_labels import block_from_labels
from .blocking import BlockingRule
from .comparison_vector_values import compute_comparison_vector_values_sql
from .predict import predict_from_comparison_vectors_sqls
from .sql_transform import move_l_r_table_prefix_to_column_suffix

if TYPE_CHECKING:
from .linker import Linker


def truth_space_table_from_labels_with_predictions_sqls(
threshold_actual=0.5, match_weight_round_to_nearest=None
Expand Down Expand Up @@ -143,10 +147,11 @@ def truth_space_table_from_labels_with_predictions_sqls(
return sqls


def _select_found_by_blocking_rules(linker):
def _select_found_by_blocking_rules(linker: "Linker"):
brs = linker._settings_obj._blocking_rules_to_generate_predictions

if brs:
brs = [move_l_r_table_prefix_to_column_suffix(b.blocking_rule) for b in brs]
brs = [move_l_r_table_prefix_to_column_suffix(b.blocking_rule_sql) for b in brs]
brs = [f"(coalesce({b}, false))" for b in brs]
brs = " OR ".join(brs)
br_col = f" ({brs}) "
Expand Down
2 changes: 1 addition & 1 deletion splink/analyse_blocking.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def cumulative_comparisons_generated_by_blocking_rules(
for row, br in zip(br_count, brs_as_objs):
out_dict = {
"row_count": row,
"rule": br.blocking_rule,
"rule": br.blocking_rule_sql,
}
if output_chart:
cumulative_sum += row
Expand Down
54 changes: 32 additions & 22 deletions splink/blocking.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from sqlglot import parse_one
from sqlglot.expressions import Join, Column
from sqlglot.optimizer.eliminate_joins import join_condition
from typing import TYPE_CHECKING, Union
from typing import TYPE_CHECKING, List
import logging

from .misc import ensure_is_list
Expand Down Expand Up @@ -40,15 +40,20 @@ def blocking_rule_to_obj(br):
class BlockingRule:
def __init__(
self,
blocking_rule: BlockingRule | dict | str,
blocking_rule_sql: str,
salting_partitions=1,
sqlglot_dialect: str = None,
):
if sqlglot_dialect:
self._sql_dialect = sqlglot_dialect

self.blocking_rule = blocking_rule
self.preceding_rules = []
# Temporarily just to see if tests still pass
if not isinstance(blocking_rule_sql, str):
raise ValueError(
f"Blocking rule must be a string, not {type(blocking_rule_sql)}"
)
self.blocking_rule_sql = blocking_rule_sql
self.preceding_rules: List[BlockingRule] = []
self.sqlglot_dialect = sqlglot_dialect
self.salting_partitions = salting_partitions

Expand All @@ -60,40 +65,45 @@ def sql_dialect(self):
def match_key(self):
return len(self.preceding_rules)

@property
def sql(self):
# Wrapper to reveal the underlying SQL
return self.blocking_rule

def add_preceding_rules(self, rules):
rules = ensure_is_list(rules)
self.preceding_rules = rules

@property
def and_not_preceding_rules_sql(self):
if not self.preceding_rules:
return ""
def exclude_pairs_generated_by_this_rule_sql(self):
"""A SQL string specifying how to exclude the results
of THIS blocking rule from subseqent blocking statements,
so that subsequent statements do not produce duplicate pairs
"""

# Note the coalesce function is important here - otherwise
# you filter out any records with nulls in the previous rules
# meaning these comparisons get lost
return f"coalesce(({self.blocking_rule_sql}),false)"

@property
def exclude_pairs_generated_by_all_preceding_rules_sql(self):
"""A SQL string that excludes the results of ALL previous blocking rules from
the pairwise comparisons generated.
"""
if not self.preceding_rules:
return ""
or_clauses = [
f"coalesce(({r.blocking_rule}), false)" for r in self.preceding_rules
br.exclude_pairs_generated_by_this_rule_sql() for br in self.preceding_rules
]
previous_rules = " OR ".join(or_clauses)
return f"AND NOT ({previous_rules})"

@property
def salted_blocking_rules(self):
if self.salting_partitions == 1:
yield self.blocking_rule
yield self.blocking_rule_sql
else:
for n in range(self.salting_partitions):
yield f"{self.blocking_rule} and ceiling(l.__splink_salt * {self.salting_partitions}) = {n+1}" # noqa: E501
yield f"{self.blocking_rule_sql} and ceiling(l.__splink_salt * {self.salting_partitions}) = {n+1}" # noqa: E501

@property
def _parsed_join_condition(self):
br = self.blocking_rule
br = self.blocking_rule_sql
return parse_one("INNER JOIN r", into=Join).on(
br, dialect=self.sqlglot_dialect
) # using sqlglot==11.4.1
Expand Down Expand Up @@ -147,7 +157,7 @@ def as_dict(self):
"The minimal representation of the blocking rule"
output = {}

output["blocking_rule"] = self.blocking_rule
output["blocking_rule"] = self.blocking_rule_sql
output["sql_dialect"] = self.sql_dialect

if self.salting_partitions > 1 and self.sql_dialect == "spark":
Expand All @@ -157,7 +167,7 @@ def as_dict(self):

def _as_completed_dict(self):
if not self.salting_partitions > 1 and self.sql_dialect == "spark":
return self.blocking_rule
return self.blocking_rule_sql
else:
return self.as_dict()

Expand All @@ -166,7 +176,7 @@ def descr(self):
return "Custom" if not hasattr(self, "_description") else self._description

def _abbreviated_sql(self, cutoff=75):
sql = self.blocking_rule
sql = self.blocking_rule_sql
return (sql[:cutoff] + "...") if len(sql) > cutoff else sql

def __repr__(self):
Expand Down Expand Up @@ -312,7 +322,7 @@ def block_using_rules_sqls(linker: Linker):
if apply_salt:
salted_blocking_rules = br.salted_blocking_rules
else:
salted_blocking_rules = [br.blocking_rule]
salted_blocking_rules = [br.blocking_rule_sql]

for salted_br in salted_blocking_rules:
sql = f"""
Expand All @@ -324,7 +334,7 @@ def block_using_rules_sqls(linker: Linker):
inner join {linker._input_tablename_r} as r
on
({salted_br})
{br.and_not_preceding_rules_sql}
{br.exclude_pairs_generated_by_all_preceding_rules_sql}
{where_condition}
"""

Expand Down
6 changes: 3 additions & 3 deletions splink/blocking_rule_composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ def not_(*brls: BlockingRule | dict | str, salting_partitions: int = 1) -> Block

brls, sql_dialect, salt = _parse_blocking_rules(*brls)
br = brls[0]
blocking_rule = f"NOT ({br.blocking_rule})"
blocking_rule = f"NOT ({br.blocking_rule_sql})"

return BlockingRule(
blocking_rule,
Expand All @@ -314,9 +314,9 @@ def _br_merge(

brs, sql_dialect, salt = _parse_blocking_rules(*brls)
if len(brs) > 1:
conditions = (f"({br.blocking_rule})" for br in brs)
conditions = (f"({br.blocking_rule_sql})" for br in brs)
else:
conditions = (br.blocking_rule for br in brs)
conditions = (br.blocking_rule_sql for br in brs)

blocking_rule = f" {clause} ".join(conditions)

Expand Down
12 changes: 6 additions & 6 deletions splink/em_training_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def _training_log_message(self):
else:
mu = "m and u probabilities"

blocking_rule = self._blocking_rule_for_training.blocking_rule
blocking_rule = self._blocking_rule_for_training.blocking_rule_sql

logger.info(
f"Estimating the {mu} of the model by blocking on:\n"
Expand Down Expand Up @@ -176,7 +176,7 @@ def _train(self):
# check that the blocking rule actually generates _some_ record pairs,
# if not give the user a helpful message
if not cvv.as_record_dict(limit=1):
br_sql = f"`{self._blocking_rule_for_training.blocking_rule}`"
br_sql = f"`{self._blocking_rule_for_training.blocking_rule_sql}`"
raise EMTrainingException(
f"Training rule {br_sql} resulted in no record pairs. "
"This means that in the supplied data set "
Expand All @@ -195,7 +195,7 @@ def _train(self):
# in the original (main) setting object
expectation_maximisation(self, cvv)

rule = self._blocking_rule_for_training.blocking_rule
rule = self._blocking_rule_for_training.blocking_rule_sql
training_desc = f"EM, blocked on: {rule}"

# Add m and u values to original settings
Expand Down Expand Up @@ -254,7 +254,7 @@ def _blocking_adjusted_probability_two_random_records_match(self):
comp_levels = self._comparison_levels_to_reverse_blocking_rule
if not comp_levels:
comp_levels = self._original_settings_obj._get_comparison_levels_corresponding_to_training_blocking_rule( # noqa
self._blocking_rule_for_training.blocking_rule
self._blocking_rule_for_training.blocking_rule_sql
)

for cl in comp_levels:
Expand All @@ -271,7 +271,7 @@ def _blocking_adjusted_probability_two_random_records_match(self):
logger.log(
15,
f"\nProb two random records match adjusted for blocking on "
f"{self._blocking_rule_for_training.blocking_rule}: "
f"{self._blocking_rule_for_training.blocking_rule_sql}: "
f"{adjusted_prop_m:.3f}",
)
return adjusted_prop_m
Expand Down Expand Up @@ -411,7 +411,7 @@ def __repr__(self):
for cc in self._comparisons_that_cannot_be_estimated
]
)
blocking_rule = self._blocking_rule_for_training.blocking_rule
blocking_rule = self._blocking_rule_for_training.blocking_rule_sql
return (
f"<EMTrainingSession, blocking on {blocking_rule}, "
f"deactivating comparisons {deactivated_cols}>"
Expand Down
6 changes: 3 additions & 3 deletions splink/linker.py
Original file line number Diff line number Diff line change
Expand Up @@ -994,7 +994,7 @@ def _populate_probability_two_random_records_match_from_trained_values(self):
15,
"\n"
f"Probability two random records match from trained model blocking on "
f"{em_training_session._blocking_rule_for_training.blocking_rule}: "
f"{em_training_session._blocking_rule_for_training.blocking_rule_sql}: "
f"{training_lambda:,.3f}",
)

Expand Down Expand Up @@ -1630,7 +1630,7 @@ def estimate_parameters_using_expectation_maximisation(
self._initialise_df_concat_with_tf()

# Extract the blocking rule
blocking_rule = blocking_rule_to_obj(blocking_rule).blocking_rule
blocking_rule = blocking_rule_to_obj(blocking_rule).blocking_rule_sql

if comparisons_to_deactivate:
# If user provided a string, convert to Comparison object
Expand Down Expand Up @@ -3100,7 +3100,7 @@ def count_num_comparisons_from_blocking_rule(
int: The number of comparisons generated by the blocking rule
"""

blocking_rule = blocking_rule_to_obj(blocking_rule).blocking_rule
blocking_rule = blocking_rule_to_obj(blocking_rule).blocking_rule_sql

sql = vertically_concatenate_sql(self)
self._enqueue_sql(sql, "__splink__df_concat")
Expand Down
7 changes: 4 additions & 3 deletions splink/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

import logging
from copy import deepcopy
from typing import List

from .blocking import blocking_rule_to_obj
from .blocking import BlockingRule, blocking_rule_to_obj
from .charts import m_u_parameters_chart, match_weights_chart
from .comparison import Comparison
from .comparison_level import ComparisonLevel
Expand Down Expand Up @@ -125,7 +126,7 @@ def _get_additional_columns_to_retain(self):
used_by_brs = []
for br in self._blocking_rules_to_generate_predictions:
used_by_brs.extend(
get_columns_used_from_sql(br.blocking_rule, br.sql_dialect)
get_columns_used_from_sql(br.blocking_rule_sql, br.sql_dialect)
)

used_by_brs = [InputColumn(c) for c in used_by_brs]
Expand Down Expand Up @@ -300,7 +301,7 @@ def _get_comparison_by_output_column_name(self, name):
return cc
raise ValueError(f"No comparison column with name {name}")

def _brs_as_objs(self, brs_as_strings):
def _brs_as_objs(self, brs_as_strings) -> List[BlockingRule]:
brs_as_objs = [blocking_rule_to_obj(br) for br in brs_as_strings]
for n, br in enumerate(brs_as_objs):
br.add_preceding_rules(brs_as_objs[:n])
Expand Down
5 changes: 3 additions & 2 deletions splink/settings_validation/settings_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import re
from functools import reduce
from operator import and_
from typing import List

import sqlglot

Expand Down Expand Up @@ -49,9 +50,9 @@ def uid(self):
return self.clean_list_of_column_names(uid_as_tree)

@property
def blocking_rules(self):
def blocking_rules(self) -> List[str]:
brs = self.settings_obj._blocking_rules_to_generate_predictions
return [br.blocking_rule for br in brs]
return [br.blocking_rule_sql for br in brs]

@property
def comparisons(self):
Expand Down
6 changes: 3 additions & 3 deletions tests/test_blocking.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def test_binary_composition_internals_OR(test_helpers, dialect):
assert br_surname.__repr__() == exp_txt.format("Exact match", em_rule)
assert BlockingRule(em_rule).__repr__() == exp_txt.format("Custom", em_rule)

assert br_surname.blocking_rule == em_rule
assert br_surname.blocking_rule_sql == em_rule
assert br_surname.salting_partitions == 4
assert br_surname.preceding_rules == []

Expand All @@ -40,13 +40,13 @@ def test_binary_composition_internals_OR(test_helpers, dialect):
brl.exact_match_rule("help4"),
]
brs_as_objs = settings_tester._brs_as_objs(brs_as_strings)
brs_as_txt = [blocking_rule_to_obj(br).blocking_rule for br in brs_as_strings]
brs_as_txt = [blocking_rule_to_obj(br).blocking_rule_sql for br in brs_as_strings]

assert brs_as_objs[0].preceding_rules == []

def assess_preceding_rules(settings_brs_index):
br_prec = brs_as_objs[settings_brs_index].preceding_rules
br_prec_txt = [br.blocking_rule for br in br_prec]
br_prec_txt = [br.blocking_rule_sql for br in br_prec]
assert br_prec_txt == brs_as_txt[:settings_brs_index]

assess_preceding_rules(1)
Expand Down
6 changes: 3 additions & 3 deletions tests/test_blocking_rule_composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,20 @@ def binary_composition_internals(clause, comp_fun, brl, dialect):
# Test what happens when only one value is fed
# It should just report the regular outputs of our comparison level func
level = comp_fun(brl.exact_match_rule("tom"))
assert level.blocking_rule == f"l.{q}tom{q} = r.{q}tom{q}"
assert level.blocking_rule_sql == f"l.{q}tom{q} = r.{q}tom{q}"

# Exact match and null level composition
level = comp_fun(
brl.exact_match_rule("first_name"),
brl.exact_match_rule("surname"),
)
exact_match_sql = f"(l.{q}first_name{q} = r.{q}first_name{q}) {clause} (l.{q}surname{q} = r.{q}surname{q})" # noqa: E501
assert level.blocking_rule == exact_match_sql
assert level.blocking_rule_sql == exact_match_sql
# brl.not_(or_(...)) composition
level = brl.not_(
comp_fun(brl.exact_match_rule("first_name"), brl.exact_match_rule("surname")),
)
assert level.blocking_rule == f"NOT ({exact_match_sql})"
assert level.blocking_rule_sql == f"NOT ({exact_match_sql})"

# Check salting outputs
# salting included in the composition function
Expand Down
Loading

0 comments on commit 5104945

Please sign in to comment.